diff --git a/Cargo.lock b/Cargo.lock index 84c4ad45..41c5f51c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2292,16 +2292,6 @@ dependencies = [ "cipher", ] -[[package]] -name = "ctrlc" -version = "3.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697b5419f348fd5ae2478e8018cb016c00a5881c7f46c717de98ffd135a5651c" -dependencies = [ - "nix 0.29.0", - "windows-sys 0.59.0", -] - [[package]] name = "curve25519-dalek" version = "4.1.3" @@ -3578,7 +3568,7 @@ dependencies = [ "futures-channel", "futures-io", "futures-util", - "idna 1.0.3", + "idna", "ipnet", "once_cell", "rand 0.8.5", @@ -3603,7 +3593,7 @@ dependencies = [ "futures-channel", "futures-io", "futures-util", - "idna 1.0.3", + "idna", "ipnet", "once_cell", "rand 0.9.1", @@ -4077,16 +4067,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" -[[package]] -name = "idna" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "idna" version = "1.0.3" @@ -4141,12 +4121,6 @@ dependencies = [ "windows 0.52.0", ] -[[package]] -name = "if_chain" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" - [[package]] name = "igd-next" version = "0.14.3" @@ -4242,17 +4216,10 @@ dependencies = [ ] [[package]] -name = "indicatif" -version = "0.17.11" +name = "indoc" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" -dependencies = [ - "console", - "number_prefix", - "portable-atomic", - "unicode-width", - "web-time", -] +checksum = 
"f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "inout" @@ -5528,6 +5495,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -5923,7 +5899,7 @@ dependencies = [ "bitflags 1.3.2", "cfg-if", "libc", - "memoffset", + "memoffset 0.7.1", "pin-utils", ] @@ -6077,12 +6053,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - [[package]] name = "nvml-wrapper" version = "0.10.0" @@ -6211,8 +6181,6 @@ dependencies = [ "actix-web-prometheus", "alloy", "anyhow", - "async-trait", - "base64 0.22.1", "chrono", "clap", "env_logger", @@ -6220,11 +6188,10 @@ dependencies = [ "google-cloud-auth 0.18.0", "google-cloud-storage", "hex", - "iroh", "log", "mockito", + "p2p", "prometheus 0.14.0", - "rand 0.8.5", "rand 0.9.1", "redis", "redis-test", @@ -6233,6 +6200,7 @@ dependencies = [ "serde_json", "shared", "tokio", + "tokio-util", "url", "utoipa", "utoipa-swagger-ui", @@ -6257,6 +6225,21 @@ dependencies = [ "sha2", ] +[[package]] +name = "p2p" +version = "0.3.11" +dependencies = [ + "anyhow", + "libp2p", + "log", + "nalgebra", + "serde", + "tokio", + "tokio-util", + "tracing", + "void", +] + [[package]] name = "parity-scale-codec" version = "3.7.4" @@ -6713,6 +6696,51 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "prime-core" +version = "0.1.0" +dependencies = [ + "actix-web", + "alloy", + "alloy-provider", + "anyhow", + "env_logger", + "futures-util", + "hex", + "log", + "rand 0.8.5", + "redis", + "serde", + "serde_json", + "shared", + "subtle", + "tokio", + "tokio-util", + "url", + "uuid", 
+] + +[[package]] +name = "prime-protocol-py" +version = "0.1.0" +dependencies = [ + "alloy", + "alloy-provider", + "log", + "prime-core", + "pyo3", + "pyo3-log", + "pythonize", + "serde", + "serde_json", + "shared", + "test-log", + "thiserror 1.0.69", + "tokio", + "tokio-test", + "url", +] + [[package]] name = "primeorder" version = "0.13.6" @@ -6742,30 +6770,6 @@ dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -6921,6 +6925,89 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "pyo3" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +dependencies = [ + "indoc", + "libc", + "memoffset 0.9.1", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +dependencies = [ + "libc", + "pyo3-build-config", +] + 
+[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + +[[package]] +name = "pyo3-macros" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "pythonize" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597907139a488b22573158793aa7539df36ae863eba300c75f3a0d65fc475e27" +dependencies = [ + "pyo3", + "serde", +] + [[package]] name = "quanta" version = "0.10.1" @@ -8067,15 +8154,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "serde_spanned" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" -dependencies = [ - "serde", -] - [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -8219,12 +8297,14 @@ dependencies = [ "base64 0.22.1", "chrono", "dashmap", + "futures", "futures-util", "google-cloud-storage", "hex", "iroh", "log", "nalgebra", + "p2p", "rand 0.8.5", "rand 0.9.1", "redis", @@ -8233,6 +8313,7 @@ dependencies = [ "serde_json", "subtle", "tokio", + "tokio-util", "url", "utoipa", "uuid", @@ -8703,6 +8784,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + [[package]] name = "tempfile" version = "3.14.0" @@ -8716,6 +8803,28 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -8902,6 +9011,19 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-tungstenite" version = "0.24.0" @@ -8965,26 +9087,11 @@ dependencies = [ "tokio", ] -[[package]] -name = "toml" -version = "0.8.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - [[package]] name = "toml_datetime" version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" -dependencies = [ - "serde", -] [[package]] name = "toml_edit" @@ 
-8993,19 +9100,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" dependencies = [ "indexmap 2.9.0", - "serde", - "serde_spanned", "toml_datetime", - "toml_write", "winnow", ] -[[package]] -name = "toml_write" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" - [[package]] name = "tower" version = "0.5.2" @@ -9226,12 +9324,6 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" -[[package]] -name = "unicode-bidi" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" - [[package]] name = "unicode-ident" version = "1.0.18" @@ -9265,6 +9357,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "universal-hash" version = "0.5.1" @@ -9325,7 +9423,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", - "idna 1.0.3", + "idna", "percent-encoding", "serde", ] @@ -9434,13 +9532,11 @@ dependencies = [ "env_logger", "futures", "hex", - "iroh", "lazy_static", "log", "mockito", - "nalgebra", + "p2p", "prometheus 0.14.0", - "rand 0.8.5", "rand 0.9.1", "redis", "redis-test", @@ -9452,50 +9548,7 @@ dependencies = [ "tempfile", "tokio", "tokio-util", - "toml", - "url", -] - -[[package]] -name = "validator" 
-version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd" -dependencies = [ - "idna 0.4.0", - "lazy_static", - "regex", - "serde", - "serde_derive", - "serde_json", "url", - "validator_derive", -] - -[[package]] -name = "validator_derive" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af" -dependencies = [ - "if_chain", - "lazy_static", - "proc-macro-error", - "proc-macro2", - "quote", - "regex", - "syn 1.0.109", - "validator_types", -] - -[[package]] -name = "validator_types" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3" -dependencies = [ - "proc-macro2", - "syn 1.0.109", ] [[package]] @@ -10311,32 +10364,25 @@ dependencies = [ "alloy", "anyhow", "bollard", - "bytes", "chrono", "cid", "clap", "colored", "console", - "ctrlc", - "dashmap", "directories", "env_logger", "futures", - "futures-core", "futures-util", "hex", "homedir", - "indicatif", - "iroh", "lazy_static", "libc", "log", - "nalgebra", "nvml-wrapper", + "p2p", + "prime-core", "rand 0.8.5", "rand 0.9.1", - "rand_core 0.6.4", - "regex", "reqwest", "rust-ipfs", "serde", @@ -10353,15 +10399,12 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "toml", "tracing", - "tracing-log", "tracing-loki", "tracing-subscriber", "unicode-width", "url", "uuid", - "validator 0.16.1", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 00702d19..f3e786af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,12 +5,18 @@ members = [ "crates/validator", "crates/shared", "crates/orchestrator", + "crates/p2p", "crates/dev-utils", + "crates/prime-protocol-py", + "crates/prime-core", ] resolver = "2" [workspace.dependencies] shared = { path = "crates/shared" } +prime-core = { path = 
"crates/prime-core" } +p2p = { path = "crates/p2p" } + actix-web = "4.9.0" clap = { version = "4.5.27", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] } @@ -39,9 +45,9 @@ mockito = "1.7.0" iroh = "0.34.1" rand_v8 = { package = "rand", version = "0.8.5", features = ["std"] } rand_core_v6 = { package = "rand_core", version = "0.6.4", features = ["std"] } -ipld-core = "0.4" rust-ipfs = "0.14" cid = "0.11" +tracing = "0.1.41" [workspace.package] version = "0.3.11" @@ -55,3 +61,10 @@ manual_let_else = "warn" [workspace.lints.rust] unreachable_pub = "warn" + +[workspace.metadata.rust-analyzer] +# Help rust-analyzer with proc-macros +procMacro.enable = true +procMacro.attributes.enable = true +# Use a separate target directory for rust-analyzer +targetDir = true diff --git a/Makefile b/Makefile index decd07f6..5de39578 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,18 @@ up: @# Attach to session @tmux attach-session -t prime-dev +# Start Docker services and deploy contracts only +.PHONY: bootstrap +bootstrap: + @echo "Starting Docker services and deploying contracts..." + @# Start Docker services + @docker compose up -d reth redis --wait --wait-timeout 180 + @# Deploy contracts + @cd smart-contracts && sh deploy.sh && sh deploy_work_validation.sh && cd .. 
+ @# Run setup + @$(MAKE) setup + @echo "Bootstrap complete - Docker services running and contracts deployed" + # Stop development environment .PHONY: down down: @@ -268,3 +280,12 @@ deregister-worker: set -a; source ${ENV_FILE}; set +a; \ cargo run --bin worker -- deregister --compute-pool-id $${WORKER_COMPUTE_POOL_ID} --private-key-provider $${PRIVATE_KEY_PROVIDER} --private-key-node $${PRIVATE_KEY_NODE} --rpc-url $${RPC_URL} +# Python Package +.PHONY: python-install +python-install: + @cd crates/prime-protocol-py && make install + +.PHONY: python-test +python-test: + @cd crates/prime-protocol-py && make test + diff --git a/crates/dev-utils/examples/compute_pool.rs b/crates/dev-utils/examples/compute_pool.rs index 2569980c..51658d59 100644 --- a/crates/dev-utils/examples/compute_pool.rs +++ b/crates/dev-utils/examples/compute_pool.rs @@ -68,17 +68,14 @@ async fn main() -> Result<()> { compute_limit, ) .await; - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let rewards_distributor_address = contracts .compute_pool .get_reward_distributor_address(U256::from(0)) .await .unwrap(); - println!( - "Rewards distributor address: {:?}", - rewards_distributor_address - ); + println!("Rewards distributor address: {rewards_distributor_address:?}"); let rewards_distributor = RewardsDistributor::new( rewards_distributor_address, wallet.provider(), @@ -86,7 +83,7 @@ async fn main() -> Result<()> { ); let rate = U256::from(10000000000000000u64); let tx = rewards_distributor.set_reward_rate(rate).await; - println!("Setting reward rate: {:?}", tx); + println!("Setting reward rate: {tx:?}"); let reward_rate = rewards_distributor.get_reward_rate().await.unwrap(); println!( diff --git a/crates/dev-utils/examples/create_domain.rs b/crates/dev-utils/examples/create_domain.rs index 4365c764..d1da5ea2 100644 --- a/crates/dev-utils/examples/create_domain.rs +++ b/crates/dev-utils/examples/create_domain.rs @@ -59,6 +59,6 @@ async fn main() -> Result<()> { .await; 
println!("Creating domain: {}", args.domain_name); println!("Validation logic: {}", args.validation_logic); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/eject_node.rs b/crates/dev-utils/examples/eject_node.rs index e2ed03a3..142aa1cd 100644 --- a/crates/dev-utils/examples/eject_node.rs +++ b/crates/dev-utils/examples/eject_node.rs @@ -52,20 +52,20 @@ async fn main() -> Result<()> { .compute_registry .get_node(provider_address, node_address) .await; - println!("Node info: {:?}", node_info); + println!("Node info: {node_info:?}"); let tx = contracts .compute_pool .eject_node(args.pool_id, node_address) .await; println!("Ejected node {} from pool {}", args.node, args.pool_id); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let node_info = contracts .compute_registry .get_node(provider_address, node_address) .await; - println!("Post ejection node info: {:?}", node_info); + println!("Post ejection node info: {node_info:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/get_node_info.rs b/crates/dev-utils/examples/get_node_info.rs index fec5f526..79c7c120 100644 --- a/crates/dev-utils/examples/get_node_info.rs +++ b/crates/dev-utils/examples/get_node_info.rs @@ -55,9 +55,6 @@ async fn main() -> Result<()> { .await .unwrap(); - println!( - "Node Active: {}, Validated: {}, In Pool: {}", - active, validated, is_node_in_pool - ); + println!("Node Active: {active}, Validated: {validated}, In Pool: {is_node_in_pool}"); Ok(()) } diff --git a/crates/dev-utils/examples/invalidate_work.rs b/crates/dev-utils/examples/invalidate_work.rs index 78154b07..c93c8cee 100644 --- a/crates/dev-utils/examples/invalidate_work.rs +++ b/crates/dev-utils/examples/invalidate_work.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { "Invalidated work in pool {} with penalty {}", args.pool_id, args.penalty ); - println!("Transaction hash: {:?}", tx); + println!("Transaction hash: {tx:?}"); Ok(()) } 
diff --git a/crates/dev-utils/examples/mint_ai_token.rs b/crates/dev-utils/examples/mint_ai_token.rs index 5e572b40..bc43b78d 100644 --- a/crates/dev-utils/examples/mint_ai_token.rs +++ b/crates/dev-utils/examples/mint_ai_token.rs @@ -45,9 +45,9 @@ async fn main() -> Result<()> { let amount = U256::from(args.amount) * Unit::ETHER.wei(); let tx = contracts.ai_token.mint(address, amount).await; println!("Minting to address: {}", args.address); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); let balance = contracts.ai_token.balance_of(address).await; - println!("Balance: {:?}", balance); + println!("Balance: {balance:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/set_min_stake_amount.rs b/crates/dev-utils/examples/set_min_stake_amount.rs index 82644e61..2858f5c7 100644 --- a/crates/dev-utils/examples/set_min_stake_amount.rs +++ b/crates/dev-utils/examples/set_min_stake_amount.rs @@ -36,13 +36,13 @@ async fn main() -> Result<()> { .unwrap(); let min_stake_amount = U256::from(args.min_stake_amount) * Unit::ETHER.wei(); - println!("Min stake amount: {}", min_stake_amount); + println!("Min stake amount: {min_stake_amount}"); let tx = contracts .prime_network .set_stake_minimum(min_stake_amount) .await; - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/start_compute_pool.rs b/crates/dev-utils/examples/start_compute_pool.rs index b11e2b2c..a94c0b6f 100644 --- a/crates/dev-utils/examples/start_compute_pool.rs +++ b/crates/dev-utils/examples/start_compute_pool.rs @@ -41,6 +41,6 @@ async fn main() -> Result<()> { .start_compute_pool(U256::from(args.pool_id)) .await; println!("Started compute pool with id: {}", args.pool_id); - println!("Transaction: {:?}", tx); + println!("Transaction: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/submit_work.rs b/crates/dev-utils/examples/submit_work.rs index aa3b489c..0fcf20d0 100644 --- a/crates/dev-utils/examples/submit_work.rs 
+++ b/crates/dev-utils/examples/submit_work.rs @@ -64,7 +64,7 @@ async fn main() -> Result<()> { "Submitted work for node {} in pool {}", args.node, args.pool_id ); - println!("Transaction hash: {:?}", tx); + println!("Transaction hash: {tx:?}"); Ok(()) } diff --git a/crates/dev-utils/examples/test_concurrent_calls.rs b/crates/dev-utils/examples/test_concurrent_calls.rs index 47f7bbea..1bef230a 100644 --- a/crates/dev-utils/examples/test_concurrent_calls.rs +++ b/crates/dev-utils/examples/test_concurrent_calls.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { let wallet = Arc::new(Wallet::new(&args.key, Url::parse(&args.rpc_url)?).unwrap()); let price = wallet.provider.get_gas_price().await?; - println!("Gas price: {:?}", price); + println!("Gas price: {price:?}"); let current_nonce = wallet .provider @@ -50,8 +50,8 @@ async fn main() -> Result<()> { .block_id(BlockId::Number(BlockNumberOrTag::Pending)) .await?; - println!("Pending nonce: {:?}", pending_nonce); - println!("Current nonce: {:?}", current_nonce); + println!("Pending nonce: {pending_nonce:?}"); + println!("Current nonce: {current_nonce:?}"); // Unfortunately have to build all contracts atm let contracts = Arc::new( @@ -67,7 +67,7 @@ async fn main() -> Result<()> { let address = Address::from_str(&args.address).unwrap(); let amount = U256::from(args.amount) * Unit::ETHER.wei(); let random = (rand::random::() % 10) + 1; - println!("Random: {:?}", random); + println!("Random: {random:?}"); let contracts_one = contracts.clone(); let wallet_one = wallet.clone(); @@ -80,7 +80,7 @@ async fn main() -> Result<()> { let tx = retry_call(mint_call, 5, wallet_one.provider(), None) .await .unwrap(); - println!("Transaction hash I: {:?}", tx); + println!("Transaction hash I: {tx:?}"); }); let contracts_two = contracts.clone(); @@ -93,11 +93,11 @@ async fn main() -> Result<()> { let tx = retry_call(mint_call_two, 5, wallet_two.provider(), None) .await .unwrap(); - println!("Transaction hash II: {:?}", tx); + 
println!("Transaction hash II: {tx:?}"); }); let balance = contracts.ai_token.balance_of(address).await.unwrap(); - println!("Balance: {:?}", balance); + println!("Balance: {balance:?}"); tokio::time::sleep(tokio::time::Duration::from_secs(40)).await; Ok(()) } diff --git a/crates/discovery/src/api/routes/node.rs b/crates/discovery/src/api/routes/node.rs index b2cf780f..aa6ca45a 100644 --- a/crates/discovery/src/api/routes/node.rs +++ b/crates/discovery/src/api/routes/node.rs @@ -465,12 +465,10 @@ mod tests { assert_eq!(body.data, "Node registered successfully"); let nodes = app_state.node_store.get_nodes().await; - let nodes = match nodes { - Ok(nodes) => nodes, - Err(_) => { - panic!("Error getting nodes"); - } + let Ok(nodes) = nodes else { + panic!("Error getting nodes"); }; + assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].id, node.id); assert_eq!(nodes[0].last_updated, None); @@ -611,12 +609,10 @@ mod tests { assert_eq!(body.data, "Node registered successfully"); let nodes = app_state.node_store.get_nodes().await; - let nodes = match nodes { - Ok(nodes) => nodes, - Err(_) => { - panic!("Error getting nodes"); - } + let Ok(nodes) = nodes else { + panic!("Error getting nodes"); }; + assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].id, node.id); } diff --git a/crates/discovery/src/chainsync/sync.rs b/crates/discovery/src/chainsync/sync.rs index 6101c87a..1120d3cb 100644 --- a/crates/discovery/src/chainsync/sync.rs +++ b/crates/discovery/src/chainsync/sync.rs @@ -155,7 +155,7 @@ async fn sync_single_node( })?; let balance = provider.get_balance(node_address).await.map_err(|e| { - error!("Error retrieving balance for node {}: {}", node_address, e); + error!("Error retrieving balance for node {node_address}: {e}"); anyhow::anyhow!("Failed to retrieve node balance") })?; n.latest_balance = Some(balance); @@ -166,8 +166,7 @@ async fn sync_single_node( .await .map_err(|e| { error!( - "Error retrieving node info for provider {} and node {}: {}", - provider_address, 
node_address, e + "Error retrieving node info for provider {provider_address} and node {node_address}: {e}" ); anyhow::anyhow!("Failed to retrieve node info") })?; @@ -177,10 +176,7 @@ async fn sync_single_node( .get_provider(provider_address) .await .map_err(|e| { - error!( - "Error retrieving provider info for {}: {}", - provider_address, e - ); + error!("Error retrieving provider info for {provider_address}: {e}"); anyhow::anyhow!("Failed to retrieve provider info") })?; diff --git a/crates/discovery/src/store/redis.rs b/crates/discovery/src/store/redis.rs index 508815c2..c0a0c36b 100644 --- a/crates/discovery/src/store/redis.rs +++ b/crates/discovery/src/store/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/orchestrator/Cargo.toml b/crates/orchestrator/Cargo.toml index 6ac53140..ce733ee6 100644 --- a/crates/orchestrator/Cargo.toml +++ b/crates/orchestrator/Cargo.toml @@ -7,35 +7,35 @@ edition.workspace = true workspace = true [dependencies] +p2p = { workspace = true} +shared = { workspace = true } + actix-web = { workspace = true } -actix-web-prometheus = "0.1.2" alloy = { workspace = true } anyhow = { workspace = true } -async-trait = "0.1.88" -base64 = "0.22.1" chrono = { workspace = true, features = ["serde"] } clap = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } -google-cloud-auth = "0.18.0" -google-cloud-storage = "0.24.0" hex = { workspace = true } log = { workspace = true } -prometheus = "0.14.0" -rand = "0.9.0" redis = { workspace = true, features = ["tokio-comp"] } redis-test = { workspace = true } reqwest = { workspace = true } serde = { workspace = 
true } serde_json = { workspace = true } -shared = { workspace = true } tokio = { workspace = true } +tokio-util = { workspace = true } url = { workspace = true } +uuid = { workspace = true } + +actix-web-prometheus = "0.1.2" +google-cloud-auth = "0.18.0" +google-cloud-storage = "0.24.0" +prometheus = "0.14.0" +rand = "0.9.0" utoipa = { version = "5.3.0", features = ["actix_extras", "chrono", "uuid"] } utoipa-swagger-ui = { version = "9.0.2", features = ["actix-web", "debug-embed", "reqwest", "vendored"] } -uuid = { workspace = true } -iroh = { workspace = true } -rand_v8 = { workspace = true } [dev-dependencies] mockito = { workspace = true } diff --git a/crates/orchestrator/src/api/routes/groups.rs b/crates/orchestrator/src/api/routes/groups.rs index 44b22cd9..414f524a 100644 --- a/crates/orchestrator/src/api/routes/groups.rs +++ b/crates/orchestrator/src/api/routes/groups.rs @@ -236,9 +236,6 @@ async fn fetch_node_logs_p2p( match node { Some(node) => { - // Check if P2P client is available - let p2p_client = app_state.p2p_client.clone(); - // Check if node has P2P information let (worker_p2p_id, worker_p2p_addresses) = match (&node.worker_p2p_id, &node.worker_p2p_addresses) { @@ -254,11 +251,22 @@ async fn fetch_node_logs_p2p( }; // Send P2P request for task logs - match tokio::time::timeout( - Duration::from_secs(NODE_REQUEST_TIMEOUT), - p2p_client.get_task_logs(node_address, worker_p2p_id, worker_p2p_addresses), - ) - .await + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let get_task_logs_request = crate::p2p::GetTaskLogsRequest { + worker_wallet_address: node_address, + worker_p2p_id: worker_p2p_id.clone(), + worker_addresses: worker_p2p_addresses.clone(), + response_tx, + }; + if let Err(e) = app_state.get_task_logs_tx.send(get_task_logs_request).await { + error!("Failed to send GetTaskLogsRequest for node {node_address}: {e}"); + return json!({ + "success": false, + "error": format!("Failed to send request: {}", e), + "status": 
node.status.to_string() + }); + }; + match tokio::time::timeout(Duration::from_secs(NODE_REQUEST_TIMEOUT), response_rx).await { Ok(Ok(log_lines)) => { json!({ diff --git a/crates/orchestrator/src/api/routes/heartbeat.rs b/crates/orchestrator/src/api/routes/heartbeat.rs index a8110e61..4d6261f9 100644 --- a/crates/orchestrator/src/api/routes/heartbeat.rs +++ b/crates/orchestrator/src/api/routes/heartbeat.rs @@ -404,7 +404,7 @@ mod tests { let task = match task.try_into() { Ok(task) => task, - Err(e) => panic!("Failed to convert TaskRequest to Task: {}", e), + Err(e) => panic!("Failed to convert TaskRequest to Task: {e}"), }; let _ = app_state.store_context.task_store.add_task(task).await; diff --git a/crates/orchestrator/src/api/routes/nodes.rs b/crates/orchestrator/src/api/routes/nodes.rs index a260706a..9debddde 100644 --- a/crates/orchestrator/src/api/routes/nodes.rs +++ b/crates/orchestrator/src/api/routes/nodes.rs @@ -164,11 +164,22 @@ async fn restart_node_task(node_id: web::Path, app_state: Data .as_ref() .expect("worker_p2p_addresses should be present"); - match app_state - .p2p_client - .restart_task(node_address, p2p_id, p2p_addresses) - .await - { + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let restart_task_request = crate::p2p::RestartTaskRequest { + worker_wallet_address: node.address, + worker_p2p_id: p2p_id.clone(), + worker_addresses: p2p_addresses.clone(), + response_tx, + }; + if let Err(e) = app_state.restart_task_tx.send(restart_task_request).await { + error!("Failed to send restart task request: {e}"); + return HttpResponse::InternalServerError().json(json!({ + "success": false, + "error": "Failed to send restart task request" + })); + } + + match response_rx.await { Ok(_) => HttpResponse::Ok().json(json!({ "success": true, "message": "Task restarted successfully" @@ -240,11 +251,22 @@ async fn get_node_logs(node_id: web::Path, app_state: Data) -> })); }; - match app_state - .p2p_client - .get_task_logs(node_address, 
p2p_id, p2p_addresses) - .await - { + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let get_task_logs_request = crate::p2p::GetTaskLogsRequest { + worker_wallet_address: node.address, + worker_p2p_id: p2p_id.clone(), + worker_addresses: p2p_addresses.clone(), + response_tx, + }; + if let Err(e) = app_state.get_task_logs_tx.send(get_task_logs_request).await { + error!("Failed to send get task logs request: {e}"); + return HttpResponse::InternalServerError().json(json!({ + "success": false, + "error": "Failed to send get task logs request" + })); + } + + match response_rx.await { Ok(logs) => HttpResponse::Ok().json(json!({ "success": true, "logs": logs diff --git a/crates/orchestrator/src/api/routes/task.rs b/crates/orchestrator/src/api/routes/task.rs index 7cff4b6d..fa167dc7 100644 --- a/crates/orchestrator/src/api/routes/task.rs +++ b/crates/orchestrator/src/api/routes/task.rs @@ -315,8 +315,8 @@ mod tests { // Add tasks in sequence with delays for i in 1..=3 { let task: Task = TaskRequest { - image: format!("test{}", i), - name: format!("test{}", i), + image: format!("test{i}"), + name: format!("test{i}"), ..Default::default() } .try_into() diff --git a/crates/orchestrator/src/api/server.rs b/crates/orchestrator/src/api/server.rs index 095bcb6c..fc5943c9 100644 --- a/crates/orchestrator/src/api/server.rs +++ b/crates/orchestrator/src/api/server.rs @@ -5,7 +5,7 @@ use crate::api::routes::task::tasks_routes; use crate::api::routes::{heartbeat::heartbeat_routes, metrics::metrics_routes}; use crate::metrics::MetricsContext; use crate::models::node::NodeStatus; -use crate::p2p::client::P2PClient; +use crate::p2p::{GetTaskLogsRequest, RestartTaskRequest}; use crate::plugins::node_groups::NodeGroupsPlugin; use crate::scheduler::Scheduler; use crate::store::core::{RedisStore, StoreContext}; @@ -23,6 +23,7 @@ use shared::utils::StorageProvider; use shared::web3::contracts::core::builder::Contracts; use shared::web3::wallet::WalletProvider; use 
std::sync::Arc; +use tokio::sync::mpsc::Sender; use utoipa::{ openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}, Modify, OpenApi, @@ -116,17 +117,18 @@ async fn health_check(data: web::Data) -> HttpResponse { } pub(crate) struct AppState { - pub store_context: Arc, - pub storage_provider: Option>, - pub heartbeats: Arc, - pub redis_store: Arc, - pub hourly_upload_limit: i64, - pub contracts: Option>, - pub pool_id: u32, - pub scheduler: Scheduler, - pub node_groups_plugin: Option>, - pub metrics: Arc, - pub p2p_client: Arc, + pub(crate) store_context: Arc, + pub(crate) storage_provider: Option>, + pub(crate) heartbeats: Arc, + pub(crate) redis_store: Arc, + pub(crate) hourly_upload_limit: i64, + pub(crate) contracts: Option>, + pub(crate) pool_id: u32, + pub(crate) scheduler: Scheduler, + pub(crate) node_groups_plugin: Option>, + pub(crate) metrics: Arc, + pub(crate) get_task_logs_tx: Sender, + pub(crate) restart_task_tx: Sender, } #[allow(clippy::too_many_arguments)] @@ -145,7 +147,8 @@ pub async fn start_server( scheduler: Scheduler, node_groups_plugin: Option>, metrics: Arc, - p2p_client: Arc, + get_task_logs_tx: Sender, + restart_task_tx: Sender, ) -> Result<(), Error> { info!("Starting server at http://{host}:{port}"); let app_state = Data::new(AppState { @@ -159,7 +162,8 @@ pub async fn start_server( scheduler, node_groups_plugin, metrics, - p2p_client, + get_task_logs_tx, + restart_task_tx, }); let node_store = app_state.store_context.node_store.clone(); let node_store_clone = node_store.clone(); diff --git a/crates/orchestrator/src/api/tests/helper.rs b/crates/orchestrator/src/api/tests/helper.rs index ca2e65c1..f4204262 100644 --- a/crates/orchestrator/src/api/tests/helper.rs +++ b/crates/orchestrator/src/api/tests/helper.rs @@ -18,12 +18,12 @@ use std::sync::Arc; use url::Url; #[cfg(test)] -pub async fn create_test_app_state() -> Data { +pub(crate) async fn create_test_app_state() -> Data { use shared::utils::MockStorageProvider; use crate::{ - 
metrics::MetricsContext, p2p::client::P2PClient, scheduler::Scheduler, - utils::loop_heartbeats::LoopHeartbeats, ServerMode, + metrics::MetricsContext, scheduler::Scheduler, utils::loop_heartbeats::LoopHeartbeats, + ServerMode, }; let store = Arc::new(RedisStore::new_test()); @@ -46,12 +46,8 @@ pub async fn create_test_app_state() -> Data { let mock_storage = MockStorageProvider::new(); let storage_provider = Arc::new(mock_storage); let metrics = Arc::new(MetricsContext::new(1.to_string())); - let wallet = Wallet::new( - "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", - Url::parse("http://localhost:8545").unwrap(), - ) - .unwrap(); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let (get_task_logs_tx, _) = tokio::sync::mpsc::channel(1); + let (restart_task_tx, _) = tokio::sync::mpsc::channel(1); Data::new(AppState { store_context: store_context.clone(), @@ -64,17 +60,17 @@ pub async fn create_test_app_state() -> Data { scheduler, node_groups_plugin: None, metrics, - p2p_client: p2p_client.clone(), + get_task_logs_tx, + restart_task_tx, }) } #[cfg(test)] -pub async fn create_test_app_state_with_nodegroups() -> Data { +pub(crate) async fn create_test_app_state_with_nodegroups() -> Data { use shared::utils::MockStorageProvider; use crate::{ metrics::MetricsContext, - p2p::client::P2PClient, plugins::node_groups::{NodeGroupConfiguration, NodeGroupsPlugin}, scheduler::Scheduler, utils::loop_heartbeats::LoopHeartbeats, @@ -116,12 +112,8 @@ pub async fn create_test_app_state_with_nodegroups() -> Data { let mock_storage = MockStorageProvider::new(); let storage_provider = Arc::new(mock_storage); let metrics = Arc::new(MetricsContext::new(1.to_string())); - let wallet = Wallet::new( - "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", - Url::parse("http://localhost:8545").unwrap(), - ) - .unwrap(); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let (get_task_logs_tx, _) = 
tokio::sync::mpsc::channel(1); + let (restart_task_tx, _) = tokio::sync::mpsc::channel(1); Data::new(AppState { store_context: store_context.clone(), @@ -134,12 +126,13 @@ pub async fn create_test_app_state_with_nodegroups() -> Data { scheduler, node_groups_plugin, metrics, - p2p_client: p2p_client.clone(), + get_task_logs_tx, + restart_task_tx, }) } #[cfg(test)] -pub fn setup_contract() -> Contracts { +pub(crate) fn setup_contract() -> Contracts { let coordinator_key = "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97"; let rpc_url: Url = Url::parse("http://localhost:8545").unwrap(); let wallet = Wallet::new(coordinator_key, rpc_url).unwrap(); @@ -154,12 +147,12 @@ pub fn setup_contract() -> Contracts { } #[cfg(test)] -pub async fn create_test_app_state_with_metrics() -> Data { +pub(crate) async fn create_test_app_state_with_metrics() -> Data { use shared::utils::MockStorageProvider; use crate::{ - metrics::MetricsContext, p2p::client::P2PClient, scheduler::Scheduler, - utils::loop_heartbeats::LoopHeartbeats, ServerMode, + metrics::MetricsContext, scheduler::Scheduler, utils::loop_heartbeats::LoopHeartbeats, + ServerMode, }; let store = Arc::new(RedisStore::new_test()); @@ -182,12 +175,8 @@ pub async fn create_test_app_state_with_metrics() -> Data { let mock_storage = MockStorageProvider::new(); let storage_provider = Arc::new(mock_storage); let metrics = Arc::new(MetricsContext::new("0".to_string())); - let wallet = Wallet::new( - "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", - Url::parse("http://localhost:8545").unwrap(), - ) - .unwrap(); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let (get_task_logs_tx, _) = tokio::sync::mpsc::channel(1); + let (restart_task_tx, _) = tokio::sync::mpsc::channel(1); Data::new(AppState { store_context: store_context.clone(), @@ -200,6 +189,7 @@ pub async fn create_test_app_state_with_metrics() -> Data { scheduler, node_groups_plugin: None, metrics, - 
p2p_client: p2p_client.clone(), + get_task_logs_tx, + restart_task_tx, }) } diff --git a/crates/orchestrator/src/discovery/monitor.rs b/crates/orchestrator/src/discovery/monitor.rs index 56fed833..d1ea3133 100644 --- a/crates/orchestrator/src/discovery/monitor.rs +++ b/crates/orchestrator/src/discovery/monitor.rs @@ -384,15 +384,12 @@ impl DiscoveryMonitor { if let Some(balance) = discovery_node.latest_balance { if balance == U256::ZERO { - info!( - "Node {} has zero balance, marking as low balance", - node_address - ); + info!("Node {node_address} has zero balance, marking as low balance"); if let Err(e) = self .update_node_status(&node_address, NodeStatus::LowBalance) .await { - error!("Error updating node status: {}", e); + error!("Error updating node status: {e}"); } } } diff --git a/crates/orchestrator/src/lib.rs b/crates/orchestrator/src/lib.rs index 5f82d58d..19d13eba 100644 --- a/crates/orchestrator/src/lib.rs +++ b/crates/orchestrator/src/lib.rs @@ -16,7 +16,7 @@ pub use metrics::sync_service::MetricsSyncService; pub use metrics::webhook_sender::MetricsWebhookSender; pub use metrics::MetricsContext; pub use node::invite::NodeInviter; -pub use p2p::client::P2PClient; +pub use p2p::Service as P2PService; pub use plugins::node_groups::NodeGroupConfiguration; pub use plugins::node_groups::NodeGroupsPlugin; pub use plugins::webhook::WebhookConfig; diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs index f9beaccb..5f8e2af2 100644 --- a/crates/orchestrator/src/main.rs +++ b/crates/orchestrator/src/main.rs @@ -9,12 +9,13 @@ use shared::web3::contracts::core::builder::ContractBuilder; use shared::web3::wallet::Wallet; use std::sync::Arc; use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; use url::Url; use orchestrator::{ start_server, DiscoveryMonitor, LoopHeartbeats, MetricsContext, MetricsSyncService, MetricsWebhookSender, NodeGroupConfiguration, NodeGroupsPlugin, NodeInviter, NodeStatusUpdater, - P2PClient, 
RedisStore, Scheduler, SchedulerPlugin, ServerMode, StatusUpdatePlugin, + P2PService, RedisStore, Scheduler, SchedulerPlugin, ServerMode, StatusUpdatePlugin, StoreContext, WebhookConfig, WebhookPlugin, }; @@ -91,6 +92,10 @@ struct Args { /// Max healthy nodes with same endpoint #[arg(long, default_value = "1")] max_healthy_nodes_with_same_endpoint: u32, + + /// Libp2p port + #[arg(long, default_value = "4004")] + libp2p_port: u16, } #[tokio::main] @@ -143,7 +148,27 @@ async fn main() -> Result<()> { let store = Arc::new(RedisStore::new(&args.redis_store_url)); let store_context = Arc::new(StoreContext::new(store.clone())); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let keypair = p2p::Keypair::generate_ed25519(); + let cancellation_token = CancellationToken::new(); + let (p2p_service, invite_tx, get_task_logs_tx, restart_task_tx) = { + match P2PService::new( + keypair, + args.libp2p_port, + cancellation_token.clone(), + wallet.clone(), + ) { + Ok(res) => { + info!("p2p service initialized successfully"); + res + } + Err(e) => { + error!("failed to initialize p2p service: {e}"); + std::process::exit(1); + } + } + }; + + tokio::task::spawn(p2p_service.run()); let contracts = ContractBuilder::new(wallet.provider()) .with_compute_registry() @@ -297,24 +322,29 @@ async fn main() -> Result<()> { let inviter_store_context = store_context.clone(); let inviter_heartbeats = heartbeats.clone(); - tasks.spawn({ - let wallet = wallet.clone(); - let p2p_client = p2p_client.clone(); - async move { - let inviter = NodeInviter::new( - wallet, - compute_pool_id, - domain_id, - args.host.as_deref(), - Some(&args.port), - args.url.as_deref(), - inviter_store_context.clone(), - inviter_heartbeats.clone(), - p2p_client, - ); - inviter.run().await + let wallet = wallet.clone(); + let inviter = match NodeInviter::new( + wallet, + compute_pool_id, + domain_id, + args.host.as_deref(), + Some(&args.port), + args.url.as_deref(), + 
inviter_store_context.clone(), + inviter_heartbeats.clone(), + invite_tx, + ) { + Ok(inviter) => { + info!("Node inviter initialized successfully"); + inviter } - }); + Err(e) => { + error!("Failed to initialize node inviter: {e}"); + std::process::exit(1); + } + }; + + tasks.spawn(async move { inviter.run().await }); // Create status_update_plugins for status updater let mut status_updater_plugins: Vec = vec![]; @@ -387,7 +417,8 @@ async fn main() -> Result<()> { scheduler, node_groups_plugin, metrics_context, - p2p_client, + get_task_logs_tx, + restart_task_tx, ) => { if let Err(e) = res { error!("Server error: {e}"); @@ -403,6 +434,8 @@ async fn main() -> Result<()> { } } + // TODO: use cancellation token to gracefully shutdown tasks + cancellation_token.cancel(); tasks.shutdown().await; Ok(()) } diff --git a/crates/orchestrator/src/node/invite.rs b/crates/orchestrator/src/node/invite.rs index 17ae4207..8391d047 100644 --- a/crates/orchestrator/src/node/invite.rs +++ b/crates/orchestrator/src/node/invite.rs @@ -1,40 +1,40 @@ use crate::models::node::NodeStatus; use crate::models::node::OrchestratorNode; -use crate::p2p::client::P2PClient; +use crate::p2p::InviteRequest as InviteRequestWithMetadata; use crate::store::core::StoreContext; use crate::utils::loop_heartbeats::LoopHeartbeats; use alloy::primitives::utils::keccak256 as keccak; use alloy::primitives::U256; use alloy::signers::Signer; -use anyhow::Result; +use anyhow::{bail, Result}; use futures::stream; use futures::StreamExt; use log::{debug, error, info, warn}; -use shared::models::invite::InviteRequest; +use p2p::InviteRequest; +use p2p::InviteRequestUrl; use shared::web3::wallet::Wallet; use std::sync::Arc; use std::time::SystemTime; use std::time::UNIX_EPOCH; +use tokio::sync::mpsc::Sender; use tokio::time::{interval, Duration}; // Timeout constants const DEFAULT_INVITE_CONCURRENT_COUNT: usize = 32; // Max concurrent count of nodes being invited -pub struct NodeInviter<'a> { +pub struct NodeInviter 
{ wallet: Wallet, pool_id: u32, domain_id: u32, - host: Option<&'a str>, - port: Option<&'a u16>, - url: Option<&'a str>, + url: InviteRequestUrl, store_context: Arc, heartbeats: Arc, - p2p_client: Arc, + invite_tx: Sender, } -impl<'a> NodeInviter<'a> { +impl NodeInviter { #[allow(clippy::too_many_arguments)] - pub fn new( + pub fn new<'a>( wallet: Wallet, pool_id: u32, domain_id: u32, @@ -43,19 +43,31 @@ impl<'a> NodeInviter<'a> { url: Option<&'a str>, store_context: Arc, heartbeats: Arc, - p2p_client: Arc, - ) -> Self { - Self { + invite_tx: Sender, + ) -> Result { + let url = if let Some(url) = url { + InviteRequestUrl::MasterUrl(url.to_string()) + } else { + let Some(host) = host else { + bail!("either host or url must be provided"); + }; + + let Some(port) = port else { + bail!("either port or url must be provided"); + }; + + InviteRequestUrl::MasterIpPort(host.to_string(), *port) + }; + + Ok(Self { wallet, pool_id, domain_id, - host, - port, url, store_context, heartbeats, - p2p_client, - } + invite_tx, + }) } pub async fn run(&self) -> Result<()> { @@ -71,7 +83,7 @@ impl<'a> NodeInviter<'a> { } } - async fn _generate_invite( + async fn generate_invite( &self, node: &OrchestratorNode, nonce: [u8; 32], @@ -102,7 +114,7 @@ impl<'a> NodeInviter<'a> { Ok(signature) } - async fn _send_invite(&self, node: &OrchestratorNode) -> Result<(), anyhow::Error> { + async fn send_invite(&self, node: &OrchestratorNode) -> Result<(), anyhow::Error> { if node.worker_p2p_id.is_none() || node.worker_p2p_addresses.is_none() { return Err(anyhow::anyhow!("Node does not have p2p information")); } @@ -120,21 +132,11 @@ impl<'a> NodeInviter<'a> { ) .to_be_bytes(); - let invite_signature = self._generate_invite(node, nonce, expiration).await?; + let invite_signature = self.generate_invite(node, nonce, expiration).await?; let payload = InviteRequest { invite: hex::encode(invite_signature), pool_id: self.pool_id, - master_url: self.url.map(|u| u.to_string()), - master_ip: if 
self.url.is_none() { - self.host.map(|h| h.to_string()) - } else { - None - }, - master_port: if self.url.is_none() { - self.port.copied() - } else { - None - }, + url: self.url.clone(), timestamp: SystemTime::now() .duration_since(UNIX_EPOCH) .map_err(|e| anyhow::anyhow!("System time error: {}", e))? @@ -145,11 +147,19 @@ impl<'a> NodeInviter<'a> { info!("Sending invite to node: {p2p_id}"); - match self - .p2p_client - .invite_worker(node.address, p2p_id, p2p_addresses, payload) + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let invite = InviteRequestWithMetadata { + worker_wallet_address: node.address, + worker_p2p_id: p2p_id.clone(), + worker_addresses: p2p_addresses.clone(), + invite: payload, + response_tx, + }; + self.invite_tx + .send(invite) .await - { + .map_err(|_| anyhow::anyhow!("failed to send invite request"))?; + match response_rx.await { Ok(_) => { info!("Successfully invited node"); if let Err(e) = self @@ -182,7 +192,7 @@ impl<'a> NodeInviter<'a> { let invited_nodes = stream::iter(nodes.into_iter().map(|node| async move { info!("Processing node {:?}", node.address); - match self._send_invite(&node).await { + match self.send_invite(&node).await { Ok(_) => { info!("Successfully processed node {:?}", node.address); Ok(()) diff --git a/crates/orchestrator/src/p2p/client.rs b/crates/orchestrator/src/p2p/client.rs deleted file mode 100644 index 39810151..00000000 --- a/crates/orchestrator/src/p2p/client.rs +++ /dev/null @@ -1,102 +0,0 @@ -use alloy::primitives::Address; -use anyhow::Result; -use log::{info, warn}; -use shared::models::invite::InviteRequest; -use shared::p2p::{client::P2PClient as SharedP2PClient, messages::P2PMessage}; -use shared::web3::wallet::Wallet; - -pub struct P2PClient { - shared_client: SharedP2PClient, -} - -impl P2PClient { - pub async fn new(wallet: Wallet) -> Result { - let shared_client = SharedP2PClient::new(wallet).await?; - Ok(Self { shared_client }) - } - - pub async fn invite_worker( - &self, - 
worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - invite: InviteRequest, - ) -> Result<()> { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::Invite(invite), - 20, - ) - .await?; - - match response { - P2PMessage::InviteResponse { status, error } => { - if status == "ok" { - info!("Successfully invited worker {worker_p2p_id}"); - Ok(()) - } else { - let error_msg = error.unwrap_or_else(|| "Unknown error".to_string()); - warn!("Failed to invite worker {worker_p2p_id}: {error_msg}"); - Err(anyhow::anyhow!("Invite failed: {}", error_msg)) - } - } - _ => Err(anyhow::anyhow!("Unexpected response type for invite")), - } - } - - pub async fn get_task_logs( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - ) -> Result> { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::GetTaskLogs, - 20, - ) - .await?; - - match response { - P2PMessage::GetTaskLogsResponse { logs } => { - logs.map_err(|e| anyhow::anyhow!("Failed to get task logs: {}", e)) - } - _ => Err(anyhow::anyhow!( - "Unexpected response type for get_task_logs" - )), - } - } - - pub async fn restart_task( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - ) -> Result<()> { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::RestartTask, - 25, - ) - .await?; - - match response { - P2PMessage::RestartTaskResponse { result } => { - result.map_err(|e| anyhow::anyhow!("Failed to restart task: {}", e)) - } - _ => Err(anyhow::anyhow!("Unexpected response type for restart_task")), - } - } -} diff --git a/crates/orchestrator/src/p2p/mod.rs b/crates/orchestrator/src/p2p/mod.rs index 1d331315..c11ca2bf 100644 --- a/crates/orchestrator/src/p2p/mod.rs +++ 
b/crates/orchestrator/src/p2p/mod.rs @@ -1 +1,171 @@ -pub(crate) mod client; +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use futures::FutureExt; +use p2p::{Keypair, Protocols}; +use shared::p2p::OutgoingRequest; +use shared::p2p::Service as P2PService; +use shared::web3::wallet::Wallet; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio_util::sync::CancellationToken; + +pub struct Service { + inner: P2PService, + outgoing_message_tx: Sender, + invite_rx: Receiver, + get_task_logs_rx: Receiver, + restart_task_rx: Receiver, +} + +impl Service { + #[allow(clippy::type_complexity)] + pub fn new( + keypair: Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet: Wallet, + ) -> Result<( + Self, + Sender, + Sender, + Sender, + )> { + let (invite_tx, invite_rx) = tokio::sync::mpsc::channel(100); + let (get_task_logs_tx, get_task_logs_rx) = tokio::sync::mpsc::channel(100); + let (restart_task_tx, restart_task_rx) = tokio::sync::mpsc::channel(100); + let (inner, outgoing_message_tx) = P2PService::new( + keypair, + port, + cancellation_token.clone(), + wallet, + Protocols::new() + .with_invite() + .with_get_task_logs() + .with_restart() + .with_authentication(), + ) + .context("failed to create p2p service")?; + Ok(( + Self { + inner, + outgoing_message_tx, + invite_rx, + get_task_logs_rx, + restart_task_rx, + }, + invite_tx, + get_task_logs_tx, + restart_task_tx, + )) + } + + pub async fn run(self) -> Result<()> { + use futures::StreamExt as _; + + let Self { + inner, + outgoing_message_tx, + mut invite_rx, + mut get_task_logs_rx, + mut restart_task_rx, + } = self; + + tokio::task::spawn(inner.run()); + + let mut futures = FuturesUnordered::new(); + + loop { + tokio::select! 
{ + Some(request) = invite_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let p2p::Response::Invite(resp) = incoming_resp_rx.await.context("outgoing request tx channel was dropped")? else { + bail!("unexpected response type for invite request"); + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }.boxed(); + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: request.invite.into(), + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing invite request")?; + } + Some(request) = get_task_logs_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let p2p::Response::GetTaskLogs(resp) = incoming_resp_rx.await.context("outgoing request tx channel was dropped")? else { + bail!("unexpected response type for get task logs request"); + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }.boxed(); + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: p2p::Request::GetTaskLogs, + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing get task logs request")?; + } + Some(request) = restart_task_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let p2p::Response::RestartTask(resp) = incoming_resp_rx.await.context("outgoing request tx channel was dropped")? 
else { + bail!("unexpected response type for restart task request"); + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }.boxed(); + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: p2p::Request::RestartTask, + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing restart task request")?; + } + Some(res) = futures.next() => { + if let Err(e) = res { + log::error!("failed to handle response conversion: {e}"); + } + } + } + } + } +} + +pub struct InviteRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) invite: p2p::InviteRequest, + pub(crate) response_tx: tokio::sync::oneshot::Sender, +} + +pub struct GetTaskLogsRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) response_tx: tokio::sync::oneshot::Sender, +} + +pub struct RestartTaskRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) response_tx: tokio::sync::oneshot::Sender, +} diff --git a/crates/orchestrator/src/plugins/node_groups/tests.rs b/crates/orchestrator/src/plugins/node_groups/tests.rs index a7d73b36..5fc22430 100644 --- a/crates/orchestrator/src/plugins/node_groups/tests.rs +++ b/crates/orchestrator/src/plugins/node_groups/tests.rs @@ -276,9 +276,7 @@ async fn test_group_formation_with_multiple_configs() { let _ = plugin.try_form_new_groups().await; let mut conn = plugin.store.client.get_connection().unwrap(); - let groups: Vec = conn - .keys(format!("{}*", GROUP_KEY_PREFIX).as_str()) - .unwrap(); + let 
groups: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*").as_str()).unwrap(); assert_eq!(groups.len(), 2); // Verify group was created @@ -1102,7 +1100,7 @@ async fn test_node_cannot_be_in_multiple_groups() { ); // Get all group keys - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); let group_copy = group_keys.clone(); // There should be exactly one group @@ -1167,7 +1165,7 @@ async fn test_node_cannot_be_in_multiple_groups() { let _ = plugin.try_form_new_groups().await; // Get updated group keys - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); // There should now be exactly two groups assert_eq!( @@ -1544,7 +1542,7 @@ async fn test_task_observer() { let _ = store_context.task_store.add_task(task2.clone()).await; let _ = plugin.try_form_new_groups().await; let all_tasks = store_context.task_store.get_all_tasks().await.unwrap(); - println!("All tasks: {:?}", all_tasks); + println!("All tasks: {all_tasks:?}"); assert_eq!(all_tasks.len(), 2); assert!(all_tasks[0].id != all_tasks[1].id); let topologies = get_task_topologies(&task).unwrap(); @@ -1588,7 +1586,7 @@ async fn test_task_observer() { .unwrap(); assert!(group_3.is_some()); let all_tasks = store_context.task_store.get_all_tasks().await.unwrap(); - println!("All tasks: {:?}", all_tasks); + println!("All tasks: {all_tasks:?}"); assert_eq!(all_tasks.len(), 2); // Manually assign the first task to the group to test immediate dissolution let group_3_before = plugin @@ -1615,7 +1613,7 @@ async fn test_task_observer() { .get_node_group(&node_3.address.to_string()) .await .unwrap(); - println!("Group 3 after task deletion: {:?}", group_3); + println!("Group 3 after task deletion: {group_3:?}"); // With new behavior, group should be dissolved immediately when its assigned task is deleted assert!(group_3.is_none()); @@ 
-1833,7 +1831,7 @@ async fn test_group_formation_priority() { let nodes: Vec<_> = (1..=4) .map(|i| { create_test_node( - &format!("0x{}234567890123456789012345678901234567890", i), + &format!("0x{i}234567890123456789012345678901234567890"), NodeStatus::Healthy, None, ) @@ -1863,7 +1861,7 @@ async fn test_group_formation_priority() { // Verify: Should form one 3-node group + one 1-node group // NOT four 1-node groups let mut conn = plugin.store.client.get_connection().unwrap(); - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); assert_eq!(group_keys.len(), 2, "Should form exactly 2 groups"); // Check group compositions @@ -1944,7 +1942,7 @@ async fn test_multiple_groups_same_configuration() { let nodes: Vec<_> = (1..=6) .map(|i| { create_test_node( - &format!("0x{}234567890123456789012345678901234567890", i), + &format!("0x{i}234567890123456789012345678901234567890"), NodeStatus::Healthy, None, ) @@ -1958,7 +1956,7 @@ async fn test_multiple_groups_same_configuration() { // Verify: Should create 3 groups of 2 nodes each let mut conn = plugin.store.client.get_connection().unwrap(); - let group_keys: Vec = conn.keys(format!("{}*", GROUP_KEY_PREFIX)).unwrap(); + let group_keys: Vec = conn.keys(format!("{GROUP_KEY_PREFIX}*")).unwrap(); assert_eq!(group_keys.len(), 3, "Should form exactly 3 groups"); // Verify all groups have exactly 2 nodes and same configuration @@ -2663,7 +2661,7 @@ async fn test_no_merge_when_policy_disabled() { // Create 3 nodes let nodes: Vec<_> = (1..=3) - .map(|i| create_test_node(&format!("0x{:040x}", i), NodeStatus::Healthy, None)) + .map(|i| create_test_node(&format!("0x{i:040x}"), NodeStatus::Healthy, None)) .collect(); for node in &nodes { diff --git a/crates/orchestrator/src/scheduler/mod.rs b/crates/orchestrator/src/scheduler/mod.rs index 711f313f..d5ffa506 100644 --- a/crates/orchestrator/src/scheduler/mod.rs +++ 
b/crates/orchestrator/src/scheduler/mod.rs @@ -144,12 +144,12 @@ mod tests { ); assert_eq!( env_vars.get("NODE_VAR").unwrap(), - &format!("node-{}", node_address) + &format!("node-{node_address}") ); // Check cmd replacement let cmd = returned_task.cmd.unwrap(); assert_eq!(cmd[0], format!("--task={}", task.id)); - assert_eq!(cmd[1], format!("--node={}", node_address)); + assert_eq!(cmd[1], format!("--node={node_address}")); } } diff --git a/crates/orchestrator/src/status_update/mod.rs b/crates/orchestrator/src/status_update/mod.rs index b2738488..67140cbc 100644 --- a/crates/orchestrator/src/status_update/mod.rs +++ b/crates/orchestrator/src/status_update/mod.rs @@ -372,6 +372,7 @@ async fn process_node( } #[cfg(test)] +#[allow(clippy::unused_async)] async fn is_node_in_pool(_: Contracts, _: u32, _: &OrchestratorNode) -> bool { true } @@ -433,7 +434,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let heartbeat = HeartbeatRequest { address: node.address.to_string(), @@ -451,7 +452,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } let _ = updater.process_nodes().await; @@ -510,7 +511,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -563,7 +564,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -623,7 +624,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } if let Err(e) = app_state .store_context @@ -631,7 +632,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy 
counter: {e}"); } let mode = ServerMode::Full; @@ -687,7 +688,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); }; let heartbeat = HeartbeatRequest { @@ -702,7 +703,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } if let Err(e) = app_state .store_context @@ -710,7 +711,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; @@ -772,7 +773,7 @@ mod tests { .set_unhealthy_counter(&node1.address, 1) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); }; if let Err(e) = app_state .store_context @@ -780,7 +781,7 @@ mod tests { .add_node(node1.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let node2 = OrchestratorNode { @@ -797,7 +798,7 @@ mod tests { .add_node(node2.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; @@ -873,7 +874,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } if let Err(e) = app_state .store_context @@ -881,7 +882,7 @@ mod tests { .set_unhealthy_counter(&node.address, 2) .await { - error!("Error setting unhealthy counter: {}", e); + error!("Error setting unhealthy counter: {e}"); } let mode = ServerMode::Full; @@ -926,7 +927,7 @@ mod tests { .beat(&heartbeat) .await { - error!("Heartbeat Error: {}", e); + error!("Heartbeat Error: {e}"); } sleep(Duration::from_secs(5)).await; @@ -960,7 +961,7 @@ mod tests { .add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( @@ -1029,7 +1030,7 @@ mod tests { 
.add_node(node.clone()) .await { - error!("Error adding node: {}", e); + error!("Error adding node: {e}"); } let mode = ServerMode::Full; let updater = NodeStatusUpdater::new( diff --git a/crates/orchestrator/src/store/core/redis.rs b/crates/orchestrator/src/store/core/redis.rs index 79f57ce8..3b524b33 100644 --- a/crates/orchestrator/src/store/core/redis.rs +++ b/crates/orchestrator/src/store/core/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/orchestrator/src/store/domains/heartbeat_store.rs b/crates/orchestrator/src/store/domains/heartbeat_store.rs index b2f8138a..8bb43374 100644 --- a/crates/orchestrator/src/store/domains/heartbeat_store.rs +++ b/crates/orchestrator/src/store/domains/heartbeat_store.rs @@ -80,7 +80,7 @@ impl HeartbeatStore { .get_multiplexed_async_connection() .await .map_err(|_| anyhow!("Failed to get connection"))?; - let key = format!("{}:{}", ORCHESTRATOR_UNHEALTHY_COUNTER_KEY, address); + let key = format!("{ORCHESTRATOR_UNHEALTHY_COUNTER_KEY}:{address}"); con.set(key, counter.to_string()) .await .map_err(|_| anyhow!("Failed to set value")) diff --git a/crates/orchestrator/src/store/domains/metrics_store.rs b/crates/orchestrator/src/store/domains/metrics_store.rs index 1a0d79ac..5520860a 100644 --- a/crates/orchestrator/src/store/domains/metrics_store.rs +++ b/crates/orchestrator/src/store/domains/metrics_store.rs @@ -145,7 +145,7 @@ impl MetricsStore { task_id: &str, ) -> Result> { let mut con = self.redis.client.get_multiplexed_async_connection().await?; - let pattern = format!("{}:*", ORCHESTRATOR_NODE_METRICS_STORE); + let pattern = 
format!("{ORCHESTRATOR_NODE_METRICS_STORE}:*"); // Scan all node keys let mut iter: redis::AsyncIter = con.scan_match(&pattern).await?; diff --git a/crates/p2p/Cargo.toml b/crates/p2p/Cargo.toml new file mode 100644 index 00000000..498fbd29 --- /dev/null +++ b/crates/p2p/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "p2p" +version.workspace = true +edition.workspace = true + +[dependencies] +libp2p = { version = "0.54", features = ["request-response", "identify", "ping", "mdns", "noise", "tcp", "autonat", "kad", "tokio", "cbor", "macros", "yamux"] } +void = "1.0" + +anyhow = {workspace = true} +nalgebra = {workspace = true} +serde = {workspace = true} +tokio = {workspace = true, features = ["sync"]} +tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } +log = { workspace = true } + +[lints] +workspace = true diff --git a/crates/p2p/src/behaviour.rs b/crates/p2p/src/behaviour.rs new file mode 100644 index 00000000..399693b5 --- /dev/null +++ b/crates/p2p/src/behaviour.rs @@ -0,0 +1,186 @@ +use anyhow::Context as _; +use anyhow::Result; +use libp2p::autonat; +use libp2p::connection_limits; +use libp2p::connection_limits::ConnectionLimits; +use libp2p::identify; +use libp2p::identity; +use libp2p::kad; +// use libp2p::kad::store::MemoryStore; +use libp2p::mdns; +use libp2p::ping; +use libp2p::request_response; +use libp2p::swarm::NetworkBehaviour; +use log::debug; +use std::time::Duration; + +use crate::message::IncomingMessage; +use crate::message::{Request, Response}; +use crate::Protocols; +use crate::PRIME_STREAM_PROTOCOL; + +#[derive(NetworkBehaviour)] +#[behaviour(to_swarm = "BehaviourEvent")] +pub(crate) struct Behaviour { + // connection gating + connection_limits: connection_limits::Behaviour, + + // discovery + mdns: mdns::tokio::Behaviour, + // comment out kademlia for now as it requires bootnodes to be provided + // kademlia: kad::Behaviour, + + // protocols + identify: identify::Behaviour, + ping: ping::Behaviour, + 
request_response: request_response::cbor::Behaviour, + + // nat traversal + autonat: autonat::Behaviour, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum BehaviourEvent { + Autonat(autonat::Event), + Identify(identify::Event), + Kademlia(kad::Event), + Mdns(mdns::Event), + Ping(ping::Event), + RequestResponse(request_response::Event), +} + +impl From for BehaviourEvent { + fn from(_: void::Void) -> Self { + unreachable!("void::Void cannot be converted to BehaviourEvent") + } +} + +impl From for BehaviourEvent { + fn from(event: autonat::Event) -> Self { + BehaviourEvent::Autonat(event) + } +} + +impl From for BehaviourEvent { + fn from(event: kad::Event) -> Self { + BehaviourEvent::Kademlia(event) + } +} + +impl From for BehaviourEvent { + fn from(event: libp2p::mdns::Event) -> Self { + BehaviourEvent::Mdns(event) + } +} + +impl From for BehaviourEvent { + fn from(event: ping::Event) -> Self { + BehaviourEvent::Ping(event) + } +} + +impl From for BehaviourEvent { + fn from(event: identify::Event) -> Self { + BehaviourEvent::Identify(event) + } +} + +impl From> for BehaviourEvent { + fn from(event: request_response::Event) -> Self { + BehaviourEvent::RequestResponse(event) + } +} + +impl Behaviour { + pub(crate) fn new( + keypair: &identity::Keypair, + protocols: Protocols, + agent_version: String, + ) -> Result { + let peer_id = keypair.public().to_peer_id(); + + let protocols = protocols.into_iter().map(|protocol| { + ( + protocol.as_stream_protocol(), + request_response::ProtocolSupport::Full, // TODO: configure inbound/outbound based on node role and protocol + ) + }); + + let autonat = autonat::Behaviour::new(peer_id, autonat::Config::default()); + let connection_limits = connection_limits::Behaviour::new( + ConnectionLimits::default().with_max_established(Some(100)), + ); + + let mdns = mdns::tokio::Behaviour::new(mdns::Config::default(), peer_id) + .context("failed to create mDNS behaviour")?; + // let kademlia = 
kad::Behaviour::new(peer_id, MemoryStore::new(peer_id)); + + let identify = identify::Behaviour::new( + identify::Config::new(PRIME_STREAM_PROTOCOL.to_string(), keypair.public()) + .with_agent_version(agent_version), + ); + let ping = ping::Behaviour::new(ping::Config::new().with_interval(Duration::from_secs(10))); + + Ok(Self { + autonat, + connection_limits, + // kademlia, + mdns, + identify, + ping, + request_response: request_response::cbor::Behaviour::new( + protocols, + request_response::Config::default(), + ), + }) + } + + pub(crate) fn request_response( + &mut self, + ) -> &mut request_response::cbor::Behaviour { + &mut self.request_response + } +} + +impl BehaviourEvent { + pub(crate) async fn handle(self, message_tx: tokio::sync::mpsc::Sender) { + match self { + BehaviourEvent::Autonat(_event) => {} + BehaviourEvent::Identify(_event) => {} + BehaviourEvent::Kademlia(_event) => { // TODO: potentially on outbound queries + } + BehaviourEvent::Mdns(_event) => {} + BehaviourEvent::Ping(_event) => {} + BehaviourEvent::RequestResponse(event) => match event { + request_response::Event::Message { peer, message } => { + debug!("received message from peer {peer:?}: {message:?}"); + + // if this errors, user dropped their incoming message channel + let _ = message_tx.send(IncomingMessage { peer, message }).await; + } + request_response::Event::ResponseSent { peer, request_id } => { + debug!("response sent to peer {peer:?} for request ID {request_id:?}"); + } + request_response::Event::InboundFailure { + peer, + request_id, + error, + } => { + debug!( + "inbound failure from peer {peer:?} for request ID {request_id:?}: {error}" + ); + } + request_response::Event::OutboundFailure { + peer, + request_id, + error, + } => { + debug!( + "outbound failure to peer {peer:?} for request ID {request_id:?}: {error}" + ); + } + }, + } + } +} diff --git a/crates/p2p/src/lib.rs b/crates/p2p/src/lib.rs new file mode 100644 index 00000000..f5bc648c --- /dev/null +++ 
b/crates/p2p/src/lib.rs @@ -0,0 +1,399 @@ +use anyhow::Context; +use anyhow::Result; +use libp2p::noise; +use libp2p::swarm::SwarmEvent; +use libp2p::tcp; +use libp2p::yamux; +use libp2p::Swarm; +use libp2p::SwarmBuilder; +use libp2p::{identity, Transport}; +use log::debug; +use std::time::Duration; + +mod behaviour; +mod message; +mod protocol; + +use behaviour::Behaviour; + +pub use message::*; +pub use protocol::*; + +pub type Libp2pIncomingMessage = libp2p::request_response::Message; +pub type ResponseChannel = libp2p::request_response::ResponseChannel; +pub type PeerId = libp2p::PeerId; +pub type Multiaddr = libp2p::Multiaddr; +pub type Keypair = libp2p::identity::Keypair; + +pub const PRIME_STREAM_PROTOCOL: libp2p::StreamProtocol = + libp2p::StreamProtocol::new("/prime/1.0.0"); +// TODO: force this to be passed by the user +pub const DEFAULT_AGENT_VERSION: &str = "prime-node/0.1.0"; + +pub struct Node { + peer_id: PeerId, + listen_addrs: Vec, + swarm: Swarm, + bootnodes: Vec, + cancellation_token: tokio_util::sync::CancellationToken, + + // channel for sending incoming messages to the consumer of this library + incoming_message_tx: tokio::sync::mpsc::Sender, + + // channel for receiving outgoing messages from the consumer of this library + outgoing_message_rx: tokio::sync::mpsc::Receiver, +} + +impl Node { + pub fn peer_id(&self) -> PeerId { + self.peer_id + } + + pub fn listen_addrs(&self) -> &[libp2p::Multiaddr] { + &self.listen_addrs + } + + /// Returns the multiaddresses that this node is listening on, with the peer ID included. 
+ pub fn multiaddrs(&self) -> Vec { + self.listen_addrs + .iter() + .map(|addr| { + addr.clone() + .with_p2p(self.peer_id) + .expect("can add peer ID to multiaddr") + }) + .collect() + } + + pub async fn run(self) -> Result<()> { + use libp2p::futures::StreamExt as _; + + let Node { + peer_id: _, + listen_addrs, + mut swarm, + bootnodes, + cancellation_token, + incoming_message_tx, + mut outgoing_message_rx, + } = self; + + for addr in listen_addrs { + swarm + .listen_on(addr) + .context("swarm failed to listen on multiaddr")?; + } + + for bootnode in bootnodes { + match swarm.dial(bootnode.clone()) { + Ok(_) => {} + Err(e) => { + debug!("failed to dial bootnode {bootnode}: {e:?}"); + } + } + } + + loop { + tokio::select! { + biased; + _ = cancellation_token.cancelled() => { + debug!("cancellation token triggered, shutting down node"); + break Ok(()); + } + Some(message) = outgoing_message_rx.recv() => { + match message { + OutgoingMessage::Request((peer, addrs, request)) => { + // TODO: if we're not connected to the peer, we should dial it + for addr in addrs { + swarm.add_peer_address(peer, addr); + } + swarm.behaviour_mut().request_response().send_request(&peer, request); + } + OutgoingMessage::Response((channel, response)) => { + if let Err(e) = swarm.behaviour_mut().request_response().send_response(channel, response) { + debug!("failed to send response: {e:?}"); + } + } + } + } + event = swarm.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { + address, + .. + } => { + debug!("new listen address: {address}"); + } + SwarmEvent::ExternalAddrConfirmed { address } => { + debug!("external address confirmed: {address}"); + } + SwarmEvent::ConnectionEstablished { + peer_id, + .. + } => { + debug!("connection established with peer {peer_id}"); + } + SwarmEvent::ConnectionClosed { + peer_id, + cause, + .. 
+ } => { + debug!("connection closed with peer {peer_id}: {cause:?}"); + } + SwarmEvent::Behaviour(event) => event.handle(incoming_message_tx.clone()).await, + _ => continue, + } + }, + } + } + } +} + +pub struct NodeBuilder { + port: Option, + listen_addrs: Vec, + keypair: Option, + agent_version: Option, + protocols: Protocols, + bootnodes: Vec, + cancellation_token: Option, +} + +impl Default for NodeBuilder { + fn default() -> Self { + Self::new() + } +} + +impl NodeBuilder { + pub fn new() -> Self { + Self { + port: None, + listen_addrs: Vec::new(), + keypair: None, + agent_version: None, + protocols: Protocols::new(), + bootnodes: Vec::new(), + cancellation_token: None, + } + } + + pub fn with_port(mut self, port: u16) -> Self { + self.port = Some(port); + self + } + + pub fn with_listen_addr(mut self, addr: libp2p::Multiaddr) -> Self { + self.listen_addrs.push(addr); + self + } + + pub fn with_keypair(mut self, keypair: identity::Keypair) -> Self { + self.keypair = Some(keypair); + self + } + + pub fn with_agent_version(mut self, agent_version: String) -> Self { + self.agent_version = Some(agent_version); + self + } + + pub fn with_authentication(mut self) -> Self { + self.protocols = self.protocols.with_authentication(); + self + } + + pub fn with_hardware_challenge(mut self) -> Self { + self.protocols = self.protocols.with_hardware_challenge(); + self + } + + pub fn with_invite(mut self) -> Self { + self.protocols = self.protocols.with_invite(); + self + } + + pub fn with_get_task_logs(mut self) -> Self { + self.protocols = self.protocols.with_get_task_logs(); + self + } + + pub fn with_restart(mut self) -> Self { + self.protocols = self.protocols.with_restart(); + self + } + + pub fn with_general(mut self) -> Self { + self.protocols = self.protocols.with_general(); + self + } + + pub fn with_protocols(mut self, protocols: Protocols) -> Self { + self.protocols.join(protocols); + self + } + + pub fn with_bootnode(mut self, bootnode: Multiaddr) -> Self { + 
self.bootnodes.push(bootnode); + self + } + + pub fn with_bootnodes(mut self, bootnodes: I) -> Self + where + I: IntoIterator, + T: Into, + { + for bootnode in bootnodes { + self.bootnodes.push(bootnode.into()); + } + self + } + + pub fn with_cancellation_token( + mut self, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> Self { + self.cancellation_token = Some(cancellation_token); + self + } + + pub fn try_build( + self, + ) -> Result<( + Node, + tokio::sync::mpsc::Receiver, + tokio::sync::mpsc::Sender, + )> { + let Self { + port, + mut listen_addrs, + keypair, + agent_version, + protocols, + bootnodes, + cancellation_token, + } = self; + + let keypair = keypair.unwrap_or(identity::Keypair::generate_ed25519()); + let peer_id = keypair.public().to_peer_id(); + + let transport = create_transport(&keypair)?; + let behaviour = Behaviour::new( + &keypair, + protocols, + agent_version.unwrap_or(DEFAULT_AGENT_VERSION.to_string()), + ) + .context("failed to create behaviour")?; + + let swarm = SwarmBuilder::with_existing_identity(keypair) + .with_tokio() + .with_other_transport(|_| transport)? + .with_behaviour(|_| behaviour)? 
+ .with_swarm_config(|cfg| { + cfg.with_idle_connection_timeout(Duration::from_secs(u64::MAX)) // don't disconnect from idle peers + }) + .build(); + + if listen_addrs.is_empty() { + let port = port.unwrap_or(0); + let listen_addr = format!("/ip4/0.0.0.0/tcp/{port}") + .parse() + .expect("can parse valid multiaddr"); + listen_addrs.push(listen_addr); + } + + let (incoming_message_tx, incoming_message_rx) = tokio::sync::mpsc::channel(100); + let (outgoing_message_tx, outgoing_message_rx) = tokio::sync::mpsc::channel(100); + + Ok(( + Node { + peer_id, + swarm, + listen_addrs, + bootnodes, + incoming_message_tx, + outgoing_message_rx, + cancellation_token: cancellation_token.unwrap_or_default(), + }, + incoming_message_rx, + outgoing_message_tx, + )) + } +} + +fn create_transport( + keypair: &identity::Keypair, +) -> Result> { + let transport = tcp::tokio::Transport::new(tcp::Config::default()) + .upgrade(libp2p::core::upgrade::Version::V1) + .authenticate(noise::Config::new(keypair)?) + .multiplex(yamux::Config::default()) + .timeout(Duration::from_secs(20)) + .boxed(); + + Ok(transport) +} + +#[cfg(test)] +mod test { + use super::NodeBuilder; + use crate::message; + + #[tokio::test] + async fn two_nodes_can_connect_and_do_request_response() { + let (node1, mut incoming_message_rx1, outgoing_message_tx1) = + NodeBuilder::new().with_get_task_logs().try_build().unwrap(); + let node1_peer_id = node1.peer_id(); + + let (node2, mut incoming_message_rx2, outgoing_message_tx2) = NodeBuilder::new() + .with_get_task_logs() + .with_bootnodes(node1.multiaddrs()) + .try_build() + .unwrap(); + let node2_peer_id = node2.peer_id(); + + tokio::spawn(async move { node1.run().await }); + tokio::spawn(async move { node2.run().await }); + + // TODO: implement a way to get peer count + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // send request from node1->node2 + let request = message::Request::GetTaskLogs; + outgoing_message_tx1 + 
.send(request.into_outgoing_message(node2_peer_id, vec![])) + .await + .unwrap(); + let message = incoming_message_rx2.recv().await.unwrap(); + assert_eq!(message.peer, node1_peer_id); + let libp2p::request_response::Message::Request { + request_id: _, + request: message::Request::GetTaskLogs, + channel, + } = message.message + else { + panic!("expected a GetTaskLogs request message"); + }; + + // send response from node2->node1 + let response = + message::Response::GetTaskLogs(message::GetTaskLogsResponse::Ok("logs".to_string())); + outgoing_message_tx2 + .send(response.into_outgoing_message(channel)) + .await + .unwrap(); + let message = incoming_message_rx1.recv().await.unwrap(); + assert_eq!(message.peer, node2_peer_id); + let libp2p::request_response::Message::Response { + request_id: _, + response: message::Response::GetTaskLogs(response), + } = message.message + else { + panic!("expected a GetTaskLogs response message"); + }; + let message::GetTaskLogsResponse::Ok(logs) = response else { + panic!("expected a successful GetTaskLogs response"); + }; + assert_eq!(logs, "logs"); + } +} diff --git a/crates/shared/src/models/challenge.rs b/crates/p2p/src/message/hardware_challenge.rs similarity index 100% rename from crates/shared/src/models/challenge.rs rename to crates/p2p/src/message/hardware_challenge.rs diff --git a/crates/p2p/src/message/mod.rs b/crates/p2p/src/message/mod.rs new file mode 100644 index 00000000..74b09c5a --- /dev/null +++ b/crates/p2p/src/message/mod.rs @@ -0,0 +1,250 @@ +use crate::Protocol; +use libp2p::PeerId; +use serde::{Deserialize, Serialize}; +use std::time::SystemTime; + +mod hardware_challenge; + +pub use hardware_challenge::*; + +#[derive(Debug)] +pub struct IncomingMessage { + pub peer: PeerId, + pub message: libp2p::request_response::Message, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum OutgoingMessage { + Request((PeerId, Vec, Request)), + Response( + ( + libp2p::request_response::ResponseChannel, + 
Response, + ), + ), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Request { + Authentication(AuthenticationRequest), + HardwareChallenge(HardwareChallengeRequest), + Invite(InviteRequest), + GetTaskLogs, + RestartTask, + General(GeneralRequest), +} + +impl Request { + pub fn into_outgoing_message( + self, + peer: PeerId, + multiaddrs: Vec, + ) -> OutgoingMessage { + OutgoingMessage::Request((peer, multiaddrs, self)) + } + + pub fn protocol(&self) -> Protocol { + match self { + Request::Authentication(_) => Protocol::Authentication, + Request::HardwareChallenge(_) => Protocol::HardwareChallenge, + Request::Invite(_) => Protocol::Invite, + Request::GetTaskLogs => Protocol::GetTaskLogs, + Request::RestartTask => Protocol::Restart, + Request::General(_) => Protocol::General, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Response { + Authentication(AuthenticationResponse), + HardwareChallenge(HardwareChallengeResponse), + Invite(InviteResponse), + GetTaskLogs(GetTaskLogsResponse), + RestartTask(RestartTaskResponse), + General(GeneralResponse), +} + +impl Response { + pub fn into_outgoing_message( + self, + channel: libp2p::request_response::ResponseChannel, + ) -> OutgoingMessage { + OutgoingMessage::Response((channel, self)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AuthenticationRequest { + Initiation(AuthenticationInitiationRequest), + Solution(AuthenticationSolutionRequest), +} + +impl From for Request { + fn from(request: AuthenticationRequest) -> Self { + Request::Authentication(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AuthenticationResponse { + Initiation(AuthenticationInitiationResponse), + Solution(AuthenticationSolutionResponse), +} + +impl From for Response { + fn from(response: AuthenticationResponse) -> Self { + Response::Authentication(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthenticationInitiationRequest { + 
pub message: String, +} + +impl From for Request { + fn from(request: AuthenticationInitiationRequest) -> Self { + Request::Authentication(AuthenticationRequest::Initiation(request)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthenticationInitiationResponse { + pub signature: String, + pub message: String, +} + +impl From for Response { + fn from(response: AuthenticationInitiationResponse) -> Self { + Response::Authentication(AuthenticationResponse::Initiation(response)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthenticationSolutionRequest { + pub signature: String, +} + +impl From for Request { + fn from(request: AuthenticationSolutionRequest) -> Self { + Request::Authentication(AuthenticationRequest::Solution(request)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AuthenticationSolutionResponse { + Granted, + Rejected, +} + +impl From for Response { + fn from(response: AuthenticationSolutionResponse) -> Self { + Response::Authentication(AuthenticationResponse::Solution(response)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HardwareChallengeRequest { + pub challenge: ChallengeRequest, + pub timestamp: SystemTime, +} + +impl From for Request { + fn from(request: HardwareChallengeRequest) -> Self { + Request::HardwareChallenge(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HardwareChallengeResponse { + pub response: ChallengeResponse, + pub timestamp: SystemTime, +} + +impl From for Response { + fn from(response: HardwareChallengeResponse) -> Self { + Response::HardwareChallenge(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum InviteRequestUrl { + MasterUrl(String), + MasterIpPort(String, u16), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteRequest { + pub invite: String, + pub pool_id: u32, + pub url: InviteRequestUrl, + pub timestamp: u64, + pub expiration: [u8; 32], + pub 
nonce: [u8; 32], +} + +impl From for Request { + fn from(request: InviteRequest) -> Self { + Request::Invite(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum InviteResponse { + Ok, + Error(String), +} + +impl From for Response { + fn from(response: InviteResponse) -> Self { + Response::Invite(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GetTaskLogsResponse { + Ok(String), + Error(String), +} + +impl From for Response { + fn from(response: GetTaskLogsResponse) -> Self { + Response::GetTaskLogs(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RestartTaskResponse { + Ok, + Error(String), +} + +impl From for Response { + fn from(response: RestartTaskResponse) -> Self { + Response::RestartTask(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneralRequest { + data: Vec, +} + +impl From for Request { + fn from(request: GeneralRequest) -> Self { + Request::General(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneralResponse { + data: Vec, +} + +impl From for Response { + fn from(response: GeneralResponse) -> Self { + Response::General(response) + } +} diff --git a/crates/p2p/src/protocol.rs b/crates/p2p/src/protocol.rs new file mode 100644 index 00000000..f721bea6 --- /dev/null +++ b/crates/p2p/src/protocol.rs @@ -0,0 +1,113 @@ +use libp2p::StreamProtocol; +use std::{collections::HashSet, hash::Hash}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Protocol { + // validator or orchestrator -> worker + Authentication, + // validator -> worker + HardwareChallenge, + // orchestrator -> worker + Invite, + // any -> worker + GetTaskLogs, + // any -> worker + Restart, + // any -> any + General, +} + +impl Protocol { + pub(crate) fn as_stream_protocol(&self) -> StreamProtocol { + match self { + Protocol::Authentication => StreamProtocol::new("/prime/authentication/1.0.0"), + Protocol::HardwareChallenge => 
StreamProtocol::new("/prime/hardware_challenge/1.0.0"), + Protocol::Invite => StreamProtocol::new("/prime/invite/1.0.0"), + Protocol::GetTaskLogs => StreamProtocol::new("/prime/get_task_logs/1.0.0"), + Protocol::Restart => StreamProtocol::new("/prime/restart/1.0.0"), + Protocol::General => StreamProtocol::new("/prime/general/1.0.0"), + } + } +} + +#[derive(Debug, Clone)] +pub struct Protocols(HashSet); + +impl Default for Protocols { + fn default() -> Self { + Self::new() + } +} + +impl Protocols { + pub fn new() -> Self { + Self(HashSet::new()) + } + + pub fn has_authentication(&self) -> bool { + self.0.contains(&Protocol::Authentication) + } + + pub fn has_hardware_challenge(&self) -> bool { + self.0.contains(&Protocol::HardwareChallenge) + } + + pub fn has_invite(&self) -> bool { + self.0.contains(&Protocol::Invite) + } + + pub fn has_get_task_logs(&self) -> bool { + self.0.contains(&Protocol::GetTaskLogs) + } + + pub fn has_restart(&self) -> bool { + self.0.contains(&Protocol::Restart) + } + + pub fn has_general(&self) -> bool { + self.0.contains(&Protocol::General) + } + + pub fn with_authentication(mut self) -> Self { + self.0.insert(Protocol::Authentication); + self + } + + pub fn with_hardware_challenge(mut self) -> Self { + self.0.insert(Protocol::HardwareChallenge); + self + } + + pub fn with_invite(mut self) -> Self { + self.0.insert(Protocol::Invite); + self + } + + pub fn with_get_task_logs(mut self) -> Self { + self.0.insert(Protocol::GetTaskLogs); + self + } + + pub fn with_restart(mut self) -> Self { + self.0.insert(Protocol::Restart); + self + } + + pub fn with_general(mut self) -> Self { + self.0.insert(Protocol::General); + self + } + + pub(crate) fn join(&mut self, other: Protocols) { + self.0.extend(other.0); + } +} + +impl IntoIterator for Protocols { + type Item = Protocol; + type IntoIter = std::collections::hash_set::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} diff --git a/crates/prime-core/Cargo.toml 
b/crates/prime-core/Cargo.toml new file mode 100644 index 00000000..bfcef45e --- /dev/null +++ b/crates/prime-core/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "prime-core" +version = "0.1.0" +edition = "2021" + +[lints] +workspace = true + +[lib] +name = "prime_core" +path = "src/lib.rs" + +[dependencies] +shared = { workspace = true } +alloy = { workspace = true } +alloy-provider = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +url = { workspace = true } +actix-web = { workspace = true } +anyhow = { workspace = true } +futures-util = { workspace = true } +hex = { workspace = true } +uuid = { workspace = true } +log = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +redis = { workspace = true, features = ["aio", "tokio-comp"] } +rand_v8 = { workspace = true } +env_logger = { workspace = true } +subtle = "2.6.1" diff --git a/crates/prime-core/src/lib.rs b/crates/prime-core/src/lib.rs new file mode 100644 index 00000000..1bf04f8a --- /dev/null +++ b/crates/prime-core/src/lib.rs @@ -0,0 +1 @@ +pub mod operations; diff --git a/crates/prime-core/src/operations/compute_node.rs b/crates/prime-core/src/operations/compute_node.rs new file mode 100644 index 00000000..c294291a --- /dev/null +++ b/crates/prime-core/src/operations/compute_node.rs @@ -0,0 +1,92 @@ +use alloy::{primitives::utils::keccak256 as keccak, primitives::U256, signers::Signer}; +use anyhow::Result; +use shared::web3::wallet::Wallet; +use shared::web3::{contracts::core::builder::Contracts, wallet::WalletProvider}; + +pub struct ComputeNodeOperations<'c> { + provider_wallet: &'c Wallet, + node_wallet: &'c Wallet, + contracts: Contracts, +} + +impl<'c> ComputeNodeOperations<'c> { + pub fn new( + provider_wallet: &'c Wallet, + node_wallet: &'c Wallet, + contracts: Contracts, + ) -> Self { + Self { + provider_wallet, + node_wallet, + contracts, + } + } + + pub async fn check_compute_node_exists(&self) -> Result> { + let 
compute_node = self + .contracts + .compute_registry + .get_node( + self.provider_wallet.wallet.default_signer().address(), + self.node_wallet.wallet.default_signer().address(), + ) + .await; + + match compute_node { + Ok(_) => Ok(true), + Err(_) => Ok(false), + } + } + + // Returns true if the compute node was added, false if it already exists + pub async fn add_compute_node( + &self, + compute_units: U256, + ) -> Result> { + log::info!("🔄 Adding compute node"); + + if self.check_compute_node_exists().await? { + return Ok(false); + } + + log::info!("Adding compute node"); + let provider_address = self.provider_wallet.wallet.default_signer().address(); + let node_address = self.node_wallet.wallet.default_signer().address(); + let digest = keccak([provider_address.as_slice(), node_address.as_slice()].concat()); + + let signature = self + .node_wallet + .signer + .sign_message(digest.as_slice()) + .await? + .as_bytes(); + + // Create the signature bytes + let add_node_tx = self + .contracts + .prime_network + .add_compute_node(node_address, compute_units, signature.to_vec()) + .await?; + log::info!("Add node tx: {add_node_tx:?}"); + Ok(true) + } + + pub async fn remove_compute_node(&self) -> Result> { + log::info!("🔄 Removing compute node"); + + if !self.check_compute_node_exists().await? 
{ + return Ok(false); + } + + log::info!("Removing compute node"); + let provider_address = self.provider_wallet.wallet.default_signer().address(); + let node_address = self.node_wallet.wallet.default_signer().address(); + let remove_node_tx = self + .contracts + .prime_network + .remove_compute_node(provider_address, node_address) + .await?; + log::info!("Remove node tx: {remove_node_tx:?}"); + Ok(true) + } +} diff --git a/crates/prime-core/src/operations/mod.rs b/crates/prime-core/src/operations/mod.rs new file mode 100644 index 00000000..089315f5 --- /dev/null +++ b/crates/prime-core/src/operations/mod.rs @@ -0,0 +1,2 @@ +pub mod compute_node; +pub mod provider; diff --git a/crates/worker/src/operations/provider.rs b/crates/prime-core/src/operations/provider.rs similarity index 67% rename from crates/worker/src/operations/provider.rs rename to crates/prime-core/src/operations/provider.rs index fb8aba5f..c07f6189 100644 --- a/crates/worker/src/operations/provider.rs +++ b/crates/prime-core/src/operations/provider.rs @@ -1,4 +1,3 @@ -use crate::console::Console; use alloy::primitives::utils::format_ether; use alloy::primitives::{Address, U256}; use log::error; @@ -9,18 +8,14 @@ use std::{fmt, io}; use tokio::time::{sleep, Duration}; use tokio_util::sync::CancellationToken; -pub(crate) struct ProviderOperations { +pub struct ProviderOperations { wallet: Wallet, contracts: Contracts, auto_accept: bool, } impl ProviderOperations { - pub(crate) fn new( - wallet: Wallet, - contracts: Contracts, - auto_accept: bool, - ) -> Self { + pub fn new(wallet: Wallet, contracts: Contracts, auto_accept: bool) -> Self { Self { wallet, contracts, @@ -44,7 +39,7 @@ impl ProviderOperations { } } - pub(crate) fn start_monitoring(&self, cancellation_token: CancellationToken) { + pub fn start_monitoring(&self, cancellation_token: CancellationToken) { let provider_address = self.wallet.wallet.default_signer().address(); let contracts = self.contracts.clone(); @@ -58,12 +53,12 @@ impl 
ProviderOperations { loop { tokio::select! { _ = cancellation_token.cancelled() => { - Console::info("Monitor", "Shutting down provider status monitor..."); + log::info!("Shutting down provider status monitor..."); break; } _ = async { let Some(stake_manager) = contracts.stake_manager.as_ref() else { - Console::user_error("Cannot start monitoring - stake manager not initialized"); + log::error!("Cannot start monitoring - stake manager not initialized"); return; }; @@ -71,21 +66,21 @@ impl ProviderOperations { match stake_manager.get_stake(provider_address).await { Ok(stake) => { if first_check || stake != last_stake { - Console::info("🔄 Chain Sync - Provider stake", &format_ether(stake)); + log::info!("🔄 Chain Sync - Provider stake: {}", format_ether(stake)); if !first_check { if stake < last_stake { - Console::warning(&format!("Stake decreased - possible slashing detected: From {} to {}", + log::warn!("Stake decreased - possible slashing detected: From {} to {}", format_ether(last_stake), format_ether(stake) - )); + ); if stake == U256::ZERO { - Console::warning("Stake is 0 - you might have to restart the node to increase your stake (if you still have balance left)"); + log::warn!("Stake is 0 - you might have to restart the node to increase your stake (if you still have balance left)"); } } else { - Console::info("🔄 Chain Sync - Stake changed", &format!("From {} to {}", + log::info!("🔄 Chain Sync - Stake increased: From {} to {}", format_ether(last_stake), format_ether(stake) - )); + ); } } last_stake = stake; @@ -102,13 +97,7 @@ impl ProviderOperations { match contracts.ai_token.balance_of(provider_address).await { Ok(balance) => { if first_check || balance != last_balance { - Console::info("🔄 Chain Sync - Balance", &format_ether(balance)); - if !first_check { - Console::info("🔄 Chain Sync - Balance changed", &format!("From {} to {}", - format_ether(last_balance), - format_ether(balance) - )); - } + log::info!("🔄 Chain Sync - Balance: {}", format_ether(balance)); 
last_balance = balance; } Some(balance) @@ -123,12 +112,12 @@ impl ProviderOperations { match contracts.compute_registry.get_provider(provider_address).await { Ok(provider) => { if first_check || provider.is_whitelisted != last_whitelist_status { - Console::info("🔄 Chain Sync - Whitelist status", &format!("{}", provider.is_whitelisted)); + log::info!("🔄 Chain Sync - Whitelist status: {}", provider.is_whitelisted); if !first_check { - Console::info("🔄 Chain Sync - Whitelist status changed", &format!("From {} to {}", + log::info!("🔄 Chain Sync - Whitelist status changed: {} -> {}", last_whitelist_status, provider.is_whitelisted - )); + ); } last_whitelist_status = provider.is_whitelisted; } @@ -146,7 +135,7 @@ impl ProviderOperations { }); } - pub(crate) async fn check_provider_exists(&self) -> Result { + pub async fn check_provider_exists(&self) -> Result { let address = self.wallet.wallet.default_signer().address(); let provider = self @@ -159,7 +148,7 @@ impl ProviderOperations { Ok(provider.provider_address != Address::default()) } - pub(crate) async fn check_provider_whitelisted(&self) -> Result { + pub async fn check_provider_whitelisted(&self) -> Result { let address = self.wallet.wallet.default_signer().address(); let provider = self @@ -171,29 +160,32 @@ impl ProviderOperations { Ok(provider.is_whitelisted) } - - pub(crate) async fn retry_register_provider( + pub async fn retry_register_provider( &self, stake: U256, max_attempts: u32, - cancellation_token: CancellationToken, + cancellation_token: Option, ) -> Result<(), ProviderError> { - Console::title("Registering Provider"); + log::info!("Registering Provider"); let mut attempts = 0; while attempts < max_attempts || max_attempts == 0 { - Console::progress("Registering provider..."); + log::info!("Registering provider..."); match self.register_provider(stake).await { Ok(_) => { return Ok(()); } Err(e) => match e { ProviderError::NotWhitelisted | ProviderError::InsufficientBalance => { - 
Console::info("Info", "Retrying in 10 seconds..."); - tokio::select! { - _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} - _ = cancellation_token.cancelled() => { - return Err(e); + log::info!("Retrying in 10 seconds..."); + if let Some(ref token) = cancellation_token { + tokio::select! { + _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => {} + _ = token.cancelled() => { + return Err(e); + } } + } else { + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; } attempts += 1; continue; @@ -206,7 +198,7 @@ impl ProviderOperations { Err(ProviderError::Other) } - pub(crate) async fn register_provider(&self, stake: U256) -> Result<(), ProviderError> { + pub async fn register_provider(&self, stake: U256) -> Result<(), ProviderError> { let address = self.wallet.wallet.default_signer().address(); let balance: U256 = self .contracts @@ -224,42 +216,39 @@ impl ProviderOperations { let provider_exists = self.check_provider_exists().await?; if !provider_exists { - Console::info("Balance", &format_ether(balance)); - Console::info( - "ETH Balance", + log::info!("Balance: {}", &format_ether(balance)); + log::info!( + "ETH Balance: {}", &format!("{} ETH", format_ether(U256::from(eth_balance))), ); if balance < stake { - Console::user_error(&format!( - "Insufficient balance for stake: {}", - format_ether(stake) - )); + log::error!("Insufficient balance for stake: {}", format_ether(stake)); return Err(ProviderError::InsufficientBalance); } if !self.prompt_user_confirmation(&format!( "Do you want to approve staking {}?", format_ether(stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving for Stake transaction"); + log::info!("Approving for Stake transaction"); self.contracts .ai_token .approve(stake) .await .map_err(|_| ProviderError::Other)?; - 
Console::progress("Registering Provider"); + log::info!("Registering Provider"); let Ok(register_tx) = self.contracts.prime_network.register_provider(stake).await else { return Err(ProviderError::Other); }; - Console::info("Registration tx", &format!("{register_tx:?}")); + log::info!("Registration tx: {}", &format!("{register_tx:?}")); } // Get provider details again - cleanup later - Console::progress("Getting provider details"); + log::info!("Getting provider details"); let _ = self .contracts .compute_registry @@ -270,32 +259,29 @@ impl ProviderOperations { let provider_exists = self.check_provider_exists().await?; if !provider_exists { - Console::info("Balance", &format_ether(balance)); - Console::info( - "ETH Balance", + log::info!("Balance: {}", &format_ether(balance)); + log::info!( + "ETH Balance: {}", &format!("{} ETH", format_ether(U256::from(eth_balance))), ); if balance < stake { - Console::user_error(&format!( - "Insufficient balance for stake: {}", - format_ether(stake) - )); + log::error!("Insufficient balance for stake: {}", format_ether(stake)); return Err(ProviderError::InsufficientBalance); } if !self.prompt_user_confirmation(&format!( "Do you want to approve staking {}?", format_ether(stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - Console::progress("Approving Stake transaction"); + log::info!("Approving Stake transaction"); self.contracts.ai_token.approve(stake).await.map_err(|e| { error!("Failed to approve stake: {e}"); ProviderError::Other })?; - Console::progress("Registering Provider"); + log::info!("Registering Provider"); let register_tx = match self.contracts.prime_network.register_provider(stake).await { Ok(tx) => tx, Err(e) => { @@ -303,7 +289,7 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info("Registration tx", &format!("{register_tx:?}")); + 
log::info!("Registration tx: {register_tx:?}"); } let provider = self @@ -315,23 +301,23 @@ impl ProviderOperations { let provider_exists = provider.provider_address != Address::default(); if !provider_exists { - Console::user_error( - "Provider could not be registered. Please ensure your balance is high enough.", + log::error!( + "Provider could not be registered. Please ensure your balance is high enough." ); return Err(ProviderError::Other); } - Console::success("Provider registered"); + log::info!("Provider registered"); if !provider.is_whitelisted { - Console::user_error("Provider is not whitelisted yet."); + log::error!("Provider is not whitelisted yet."); return Err(ProviderError::NotWhitelisted); } Ok(()) } - pub(crate) async fn increase_stake(&self, additional_stake: U256) -> Result<(), ProviderError> { - Console::title("💰 Increasing Provider Stake"); + pub async fn increase_stake(&self, additional_stake: U256) -> Result<(), ProviderError> { + log::info!("💰 Increasing Provider Stake"); let address = self.wallet.wallet.default_signer().address(); let balance: U256 = self @@ -341,11 +327,14 @@ impl ProviderOperations { .await .map_err(|_| ProviderError::Other)?; - Console::info("Current Balance", &format_ether(balance)); - Console::info("Additional stake amount", &format_ether(additional_stake)); + log::info!("Current Balance: {}", &format_ether(balance)); + log::info!( + "Additional stake amount: {}", + &format_ether(additional_stake) + ); if balance < additional_stake { - Console::user_error("Insufficient balance for stake increase"); + log::error!("Insufficient balance for stake increase"); return Err(ProviderError::Other); } @@ -353,20 +342,20 @@ impl ProviderOperations { "Do you want to approve staking {} additional funds?", format_ether(additional_stake) )) { - Console::info("Operation cancelled by user", "Staking approval declined"); + log::info!("Operation cancelled by user: Staking approval declined"); return Err(ProviderError::UserCancelled); } - 
Console::progress("Approving additional stake"); + log::info!("Approving additional stake"); let approve_tx = self .contracts .ai_token .approve(additional_stake) .await .map_err(|_| ProviderError::Other)?; - Console::info("Transaction approved", &format!("{approve_tx:?}")); + log::info!("Transaction approved: {}", &format!("{approve_tx:?}")); - Console::progress("Increasing stake"); + log::info!("Increasing stake"); let stake_tx = match self.contracts.prime_network.stake(additional_stake).await { Ok(tx) => tx, Err(e) => { @@ -374,17 +363,15 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info( - "Stake increase transaction completed: ", - &format!("{stake_tx:?}"), + log::info!( + "Stake increase transaction completed: {}", + &format!("{stake_tx:?}") ); - Console::success("Provider stake increased successfully"); Ok(()) } - pub(crate) async fn reclaim_stake(&self, amount: U256) -> Result<(), ProviderError> { - Console::progress("Reclaiming stake"); + pub async fn reclaim_stake(&self, amount: U256) -> Result<(), ProviderError> { let reclaim_tx = match self.contracts.prime_network.reclaim_stake(amount).await { Ok(tx) => tx, Err(e) => { @@ -392,17 +379,16 @@ impl ProviderOperations { return Err(ProviderError::Other); } }; - Console::info( - "Stake reclaim transaction completed: ", - &format!("{reclaim_tx:?}"), + log::info!( + "Stake reclaim transaction completed: {}", + &format!("{reclaim_tx:?}") ); - Console::success("Provider stake reclaimed successfully"); Ok(()) } } #[derive(Debug)] -pub(crate) enum ProviderError { +pub enum ProviderError { NotWhitelisted, UserCancelled, Other, diff --git a/crates/prime-protocol-py/.gitignore b/crates/prime-protocol-py/.gitignore new file mode 100644 index 00000000..454f9f33 --- /dev/null +++ b/crates/prime-protocol-py/.gitignore @@ -0,0 +1,24 @@ +# Python +__pycache__/ +*.py[cod] +*.so +*.pyd +*.egg-info/ +dist/ + +# Virtual environments +.venv/ + +# Testing +.pytest_cache/ + +# IDE +.vscode/ 
+.idea/ + +# Rust/Maturin +target/ +Cargo.lock + +# OS +.DS_Store \ No newline at end of file diff --git a/crates/prime-protocol-py/.python-version b/crates/prime-protocol-py/.python-version new file mode 100644 index 00000000..4b7e4839 --- /dev/null +++ b/crates/prime-protocol-py/.python-version @@ -0,0 +1 @@ +3.11 \ No newline at end of file diff --git a/crates/prime-protocol-py/Cargo.toml b/crates/prime-protocol-py/Cargo.toml new file mode 100644 index 00000000..cbb7b513 --- /dev/null +++ b/crates/prime-protocol-py/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "prime-protocol-py" +version = "0.1.0" +authors = ["Prime Protocol"] +edition = "2021" +rust-version = "1.70" + +[lib] +name = "primeprotocol" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.25.1", features = ["extension-module"] } +thiserror = "1.0" +shared = { workspace = true } +prime-core = { workspace = true } +alloy = { workspace = true } +alloy-provider = { workspace = true } +tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "sync", "time", "macros"] } +url = "2.5" +log = { workspace = true } +pyo3-log = "0.12.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +pythonize = "0.25" + +[dev-dependencies] +test-log = "0.2" +tokio-test = "0.4" + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +strip = true + diff --git a/crates/prime-protocol-py/Makefile b/crates/prime-protocol-py/Makefile new file mode 100644 index 00000000..dfb10ac9 --- /dev/null +++ b/crates/prime-protocol-py/Makefile @@ -0,0 +1,22 @@ +.PHONY: install +install: + @command -v uv > /dev/null || (echo "Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" && exit 1) + @./setup.sh # Uses uv for fast package management + +# make runs recipes with /bin/sh, so activate the venv with the POSIX `.` builtin rather than bash-only `source` +.PHONY: build +build: install + @uv cache clean + @. .venv/bin/activate && maturin develop + @. .venv/bin/activate && uv pip install --force-reinstall -e . 
+ +.PHONY: clean +clean: + @rm -rf target/ dist/ *.egg-info .pytest_cache __pycache__ .venv/ + +.PHONY: help +help: + @echo "Available commands:" + @echo " make install - Setup environment and install dependencies" + @echo " make build - Build development version (includes install and cache clear)" + @echo " make clean - Clean build artifacts" \ No newline at end of file diff --git a/crates/prime-protocol-py/README.md b/crates/prime-protocol-py/README.md new file mode 100644 index 00000000..b72b39db --- /dev/null +++ b/crates/prime-protocol-py/README.md @@ -0,0 +1,90 @@ +# Prime Protocol Python Client + +## Build + +```bash +# Install uv (one-time) +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Setup and build +cd crates/prime-protocol-py +make install +``` + +## Usage + +### Worker Client with Message Queue + +The Worker Client provides a message queue system for handling P2P messages from pool owners and validators. Messages are processed in a FIFO (First-In-First-Out) manner. + +```python +from primeprotocol import WorkerClient +import asyncio + +# Initialize the worker client +client = WorkerClient( + compute_pool_id=1, + rpc_url="http://localhost:8545", + private_key_provider="your_provider_key", + private_key_node="your_node_key", +) + +# Start the client (registers on-chain and starts message listener) +client.start() + +# Poll for messages in your application loop +async def process_messages(): + while True: + # Get next message from pool owner queue + pool_msg = client.get_pool_owner_message() + if pool_msg: + print(f"Pool owner message: {pool_msg}") + # Process the message... + + # Get next message from validator queue + validator_msg = client.get_validator_message() + if validator_msg: + print(f"Validator message: {validator_msg}") + # Process the message... 
+ + await asyncio.sleep(0.1) + +# Run the message processing loop +asyncio.run(process_messages()) + +# Gracefully shut down +client.stop() +``` + +### Message Queue Features + +- **Background Listener**: Rust protocol listens for P2P messages in the background +- **FIFO Queue**: Messages are processed in the order they are received +- **Message Types**: Separate queues for pool owner, validator, and system messages +- **Mock Mode**: Currently generates mock messages for testing (P2P integration coming soon) +- **Thread-Safe**: Safe to use from async Python code + +See `examples/basic_usage.py` for a complete working example. + +## Development + +```bash +make build # Build development version +make install # Set up the environment and install dependencies +python examples/basic_usage.py # Run the example +make clean # Clean artifacts +make help # Show all commands +``` + +## Installing in other projects + +```bash +# Build the wheel +maturin build --release + +# Install with uv (recommended) +uv pip install target/wheels/primeprotocol-*.whl + +# Or install directly from source +uv pip install /path/to/prime-protocol-py/ +``` \ No newline at end of file diff --git a/crates/prime-protocol-py/examples/basic_usage.py b/crates/prime-protocol-py/examples/basic_usage.py new file mode 100644 index 00000000..02b19bd9 --- /dev/null +++ b/crates/prime-protocol-py/examples/basic_usage.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +"""Example usage of the Prime Protocol Python client.""" + +import asyncio +import logging +import os +import signal +import sys +import time +from typing import Dict, Any, Optional +from primeprotocol import WorkerClient + +FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s' +logging.basicConfig(format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def handle_pool_owner_message(message: Dict[str, Any]) -> None: + """Handle messages from pool owner""" + logging.info(f"Received message from pool owner: {message}") + + if message.get("type") == "inference_request": + 
prompt = message.get("prompt", "") + # Simulate processing the inference request + response = f"Processed: {prompt}" + + logging.info(f"Processing inference request: {prompt}") + logging.info(f"Generated response: {response}") + + # In a real implementation, you would send the response back + # client.send_response({"type": "inference_response", "result": response}) + else: + logging.info("Sending PONG response") + # client.send_response("PONG") + + +def handle_validator_message(message: Dict[str, Any]) -> None: + """Handle messages from validator""" + logging.info(f"Received message from validator: {message}") + + if message.get("type") == "inference_request": + prompt = message.get("prompt", "") + # Simulate processing the inference request + response = f"Validated: {prompt}" + + logging.info(f"Processing validation request: {prompt}") + logging.info(f"Generated response: {response}") + + # In a real implementation, you would send the response back + # client.send_response({"type": "inference_response", "result": response}) + + +def check_for_messages(client: WorkerClient) -> None: + """Check for new messages from pool owner and validator""" + try: + # Check for pool owner messages (None means the queue is empty; falsy payloads are still messages) + pool_owner_message = client.get_pool_owner_message() + if pool_owner_message is not None: + handle_pool_owner_message(pool_owner_message) + + # Check for validator messages + validator_message = client.get_validator_message() + if validator_message is not None: + handle_validator_message(validator_message) + + except Exception as e: + logging.error(f"Error checking for messages: {e}") + + +def main(): + rpc_url = os.getenv("RPC_URL", "http://localhost:8545") + pool_id = int(os.getenv("POOL_ID", "0")) # env values are strings; the client takes an integer pool id + private_key_provider = os.getenv("PRIVATE_KEY_PROVIDER", None) + private_key_node = os.getenv("PRIVATE_KEY_NODE", None) + + logging.info(f"Connecting to: {rpc_url}") + client = WorkerClient(pool_id, rpc_url, private_key_provider, private_key_node) + + def signal_handler(sig, frame): + logging.info("Received 
interrupt signal, shutting down gracefully...") + try: + client.stop() + logging.info("Client stopped successfully") + except Exception as e: + logging.error(f"Error during shutdown: {e}") + sys.exit(0) + + # Register signal handler for Ctrl+C before starting client + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + logging.info("Starting client... (Press Ctrl+C to interrupt)") + client.start() + logging.info("Setup completed. Starting message polling loop...") + print("Worker client started. Polling for messages. Press Ctrl+C to stop.") + + # Message polling loop + while True: + try: + check_for_messages(client) + time.sleep(0.1) # Small delay to prevent busy waiting + except KeyboardInterrupt: + # Handle Ctrl+C during message polling + logging.info("Keyboard interrupt received during polling") + signal_handler(signal.SIGINT, None) + break + + except KeyboardInterrupt: + # Handle Ctrl+C during client startup + logging.info("Keyboard interrupt received during startup") + signal_handler(signal.SIGINT, None) + except Exception as e: + logging.error(f"Unexpected error: {e}") + try: + client.stop() + except Exception: # best-effort shutdown; do not swallow SystemExit/KeyboardInterrupt + pass + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/prime-protocol-py/pyproject.toml b/crates/prime-protocol-py/pyproject.toml new file mode 100644 index 00000000..9834d8b4 --- /dev/null +++ b/crates/prime-protocol-py/pyproject.toml @@ -0,0 +1,37 @@ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[project] +name = "primeprotocol" +description = "Simple Python bindings for Prime Protocol client" +readme = "README.md" +requires-python = ">=3.8" +license = {text = "MIT"} +keywords = ["prime", "protocol"] +authors = [ + {name = "Prime Protocol", email = "jannik@primeintellect.ai"} +] +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: 
Python :: Implementation :: PyPy", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dynamic = ["version"] + +[project.urls] +"Homepage" = "https://github.com/primeprotocol/protocol" +"Bug Tracker" = "https://github.com/primeprotocol/protocol/issues" + +[tool.maturin] +features = ["pyo3/extension-module"] +module-name = "primeprotocol" \ No newline at end of file diff --git a/crates/prime-protocol-py/requirements-dev.txt b/crates/prime-protocol-py/requirements-dev.txt new file mode 100644 index 00000000..f2af3c5d --- /dev/null +++ b/crates/prime-protocol-py/requirements-dev.txt @@ -0,0 +1,3 @@ +# Development dependencies +maturin>=1.0,<2.0 +pytest>=7.0 \ No newline at end of file diff --git a/crates/prime-protocol-py/setup.sh b/crates/prime-protocol-py/setup.sh new file mode 100755 index 00000000..7609b236 --- /dev/null +++ b/crates/prime-protocol-py/setup.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e + +# Check if uv is installed +if ! command -v uv &> /dev/null; then + echo "Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" + exit 1 +fi + +# Setup environment +uv venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +maturin develop + +echo "Setup complete." 
\ No newline at end of file diff --git a/crates/prime-protocol-py/src/error.rs b/crates/prime-protocol-py/src/error.rs new file mode 100644 index 00000000..cf561595 --- /dev/null +++ b/crates/prime-protocol-py/src/error.rs @@ -0,0 +1,21 @@ +use thiserror::Error; + +/// Result type alias for Prime Protocol operations +pub type Result<T> = std::result::Result<T, PrimeProtocolError>; + +/// Errors that can occur in the Prime Protocol client +#[derive(Debug, Error)] +pub enum PrimeProtocolError { + /// Invalid configuration provided + #[error("Invalid configuration: {0}")] + InvalidConfig(String), + + /// Blockchain interaction error + #[error("Blockchain error: {0}")] + BlockchainError(String), + + /// General runtime error + #[error("Runtime error: {0}")] + #[allow(dead_code)] + RuntimeError(String), +} diff --git a/crates/prime-protocol-py/src/lib.rs b/crates/prime-protocol-py/src/lib.rs new file mode 100644 index 00000000..0715c33a --- /dev/null +++ b/crates/prime-protocol-py/src/lib.rs @@ -0,0 +1,19 @@ +use crate::orchestrator::OrchestratorClient; +use crate::validator::ValidatorClient; +use crate::worker::WorkerClient; +use pyo3::prelude::*; + +mod error; +mod orchestrator; +mod utils; +mod validator; +mod worker; + +#[pymodule] +fn primeprotocol(m: &Bound<'_, PyModule>) -> PyResult<()> { + pyo3_log::init(); + m.add_class::<WorkerClient>()?; + m.add_class::<ValidatorClient>()?; + m.add_class::<OrchestratorClient>()?; + Ok(()) +} diff --git a/crates/prime-protocol-py/src/orchestrator/mod.rs b/crates/prime-protocol-py/src/orchestrator/mod.rs new file mode 100644 index 00000000..c610ea6f --- /dev/null +++ b/crates/prime-protocol-py/src/orchestrator/mod.rs @@ -0,0 +1,55 @@ +use pyo3::prelude::*; + +/// Prime Protocol Orchestrator Client - for managing and distributing tasks +#[pyclass] +pub struct OrchestratorClient { + // TODO: Implement orchestrator-specific functionality +} + +#[pymethods] +impl OrchestratorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None))] + pub fn new(rpc_url: String, private_key: Option<String>) -> PyResult<Self> { 
+ // TODO: Implement orchestrator initialization + let _ = rpc_url; + let _ = private_key; + Ok(Self {}) + } + + pub fn list_validated_nodes(&self) -> PyResult> { + // TODO: Implement orchestrator node listing + Ok(vec![]) + } + + pub fn list_nodes_from_chain(&self) -> PyResult> { + // TODO: Implement orchestrator node listing from chain + Ok(vec![]) + } + + // pub fn get_node_details(&self, node_id: String) -> PyResult> { + // // TODO: Implement orchestrator node details fetching + // Ok(None) + // } + + // pub fn get_node_details_from_chain(&self, node_id: String) -> PyResult> { + // // TODO: Implement orchestrator node details fetching from chain + // Ok(None) + // } + + // pub fn send_invite_to_node(&self, node_id: String) -> PyResult<()> { + // // TODO: Implement orchestrator node invite sending + // Ok(()) + // } + + // pub fn send_request_to_node(&self, node_id: String, request: String) -> PyResult<()> { + // // TODO: Implement orchestrator node request sending + // Ok(()) + // } + + // // TODO: Sender of this message? + // pub fn read_message(&self) -> PyResult> { + // // TODO: Implement orchestrator message reading + // Ok(None) + // } +} diff --git a/crates/prime-protocol-py/src/utils/json_parser.rs b/crates/prime-protocol-py/src/utils/json_parser.rs new file mode 100644 index 00000000..b5ed4aa2 --- /dev/null +++ b/crates/prime-protocol-py/src/utils/json_parser.rs @@ -0,0 +1,8 @@ +use pyo3::prelude::*; +use pythonize::pythonize; + +/// Convert a serde_json::Value to a Python object +pub fn json_to_pyobject(py: Python, value: &serde_json::Value) -> PyObject { + // pythonize handles all the conversion automatically! 
+ pythonize(py, value).unwrap().into() +} diff --git a/crates/prime-protocol-py/src/utils/message_queue.rs b/crates/prime-protocol-py/src/utils/message_queue.rs new file mode 100644 index 00000000..43153cb1 --- /dev/null +++ b/crates/prime-protocol-py/src/utils/message_queue.rs @@ -0,0 +1,159 @@ +use pyo3::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::Mutex; +use tokio::time::{interval, Duration}; + +use crate::utils::json_parser::json_to_pyobject; + +/// Generic message that can be sent between components +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub content: serde_json::Value, + pub timestamp: u64, + pub sender: Option<String>, +} + +/// Simple message queue for handling messages +#[derive(Clone)] +pub struct MessageQueue { + queue: Arc<Mutex<VecDeque<Message>>>, + max_size: Option<usize>, + shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>, +} + +impl MessageQueue { + /// Create a new message queue + pub fn new(max_size: Option<usize>) -> Self { + Self { + queue: Arc::new(Mutex::new(VecDeque::new())), + max_size, + shutdown_tx: Arc::new(Mutex::new(None)), + } + } + + /// Push a message to the queue + pub async fn push_message(&self, message: Message) -> Result<(), String> { + let mut queue = self.queue.lock().await; + + // Check max size if configured + if let Some(max_size) = self.max_size { + if queue.len() >= max_size { + return Err(format!("Queue is full (max size: {})", max_size)); + } + } + + queue.push_back(message); + Ok(()) + } + + /// Get the next message from the queue + pub async fn get_message(&self) -> Option<PyObject> { + let mut queue = self.queue.lock().await; + + queue + .pop_front() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + } + + /// Get all messages from the queue (draining it) + pub async fn get_all_messages(&self) -> Vec<PyObject> { + let mut queue = self.queue.lock().await; + + let messages: Vec<Message> = queue.drain(..).collect(); + messages + .into_iter() + .map(|msg| 
Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + .collect() + } + + /// Peek at the next message without removing it + pub async fn peek_message(&self) -> Option<PyObject> { + let queue = self.queue.lock().await; + + queue + .front() + .map(|msg| Python::with_gil(|py| json_to_pyobject(py, &msg.content))) + } + + /// Get the size of the queue + pub async fn get_queue_size(&self) -> usize { + let queue = self.queue.lock().await; + queue.len() + } + + /// Clear the queue + pub async fn clear(&self) -> Result<(), String> { + let mut queue = self.queue.lock().await; + queue.clear(); + Ok(()) + } + + /// Start a mock message listener (for testing/development) + /// + /// Emits one mock message every `frequency` one-second ticks; `frequency` must be non-zero. + pub async fn start_mock_listener(&self, frequency: u64) -> Result<(), String> { + // `counter % frequency` in the spawned task would panic on a zero divisor, so reject it here + if frequency == 0 { + return Err("Mock listener frequency must be greater than 0".to_string()); + } + + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + + // Store the shutdown sender + { + let mut tx_guard = self.shutdown_tx.lock().await; + *tx_guard = Some(shutdown_tx); + } + + let queue_clone = self.queue.clone(); + + // Spawn background task to simulate incoming messages + tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(1)); + let mut counter = 0u64; + + loop { + tokio::select! 
{ + _ = ticker.tick() => { + if counter % frequency == 0 { + let message = Message { + content: serde_json::json!({ + "type": "mock_message", + "id": format!("mock_{}", counter), + "data": format!("Mock data #{}", counter), + }), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + sender: Some("mock_listener".to_string()), + }; + + let mut queue = queue_clone.lock().await; + queue.push_back(message); + log::debug!("Added mock message to queue"); + } + counter += 1; + } + _ = shutdown_rx.recv() => { + log::info!("Mock message listener shutting down"); + break; + } + } + } + }); + + Ok(()) + } + + /// Stop the mock listener + pub async fn stop_listener(&self) -> Result<(), String> { + if let Some(tx) = self.shutdown_tx.lock().await.take() { + let _ = tx.send(()).await; + } + Ok(()) + } +} diff --git a/crates/prime-protocol-py/src/utils/mod.rs b/crates/prime-protocol-py/src/utils/mod.rs new file mode 100644 index 00000000..da6afad7 --- /dev/null +++ b/crates/prime-protocol-py/src/utils/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod json_parser; +pub(crate) mod message_queue; diff --git a/crates/prime-protocol-py/src/validator/message_queue.rs b/crates/prime-protocol-py/src/validator/message_queue.rs new file mode 100644 index 00000000..72f1b468 --- /dev/null +++ b/crates/prime-protocol-py/src/validator/message_queue.rs @@ -0,0 +1,46 @@ +use crate::utils::message_queue::{Message, MessageQueue as GenericMessageQueue}; +use pyo3::prelude::*; + +/// Validator-specific message queue for incoming validation results +#[derive(Clone)] +pub struct MessageQueue { + inner: GenericMessageQueue, +} + +impl MessageQueue { + /// Create a new validator message queue for validation results + pub fn new() -> Self { + let inner = GenericMessageQueue::new(None); + + Self { inner } + } + + /// Get the next validation result from nodes + pub async fn get_validation_result(&self) -> Option { + self.inner.get_message().await + } + + /// Push 
a validation result (for testing or internal use) + pub async fn push_validation_result(&self, content: serde_json::Value) -> Result<(), String> { + let message = Message { + content, + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + sender: None, // Will be set to the node ID when implemented + }; + + self.inner.push_message(message).await + } + + /// Get the number of pending validation results + pub async fn get_queue_size(&self) -> usize { + self.inner.get_queue_size().await + } + + /// Clear all validation results (use with caution) + pub async fn clear(&self) -> Result<(), String> { + self.inner.clear().await + } +} diff --git a/crates/prime-protocol-py/src/validator/mod.rs b/crates/prime-protocol-py/src/validator/mod.rs new file mode 100644 index 00000000..6890e799 --- /dev/null +++ b/crates/prime-protocol-py/src/validator/mod.rs @@ -0,0 +1,108 @@ +use pyo3::prelude::*; + +pub(crate) mod message_queue; +use self::message_queue::MessageQueue; + +/// Node details for validator operations +#[pyclass] +#[derive(Clone)] +pub(crate) struct NodeDetails { + #[pyo3(get)] + pub address: String, +} + +#[pymethods] +impl NodeDetails { + #[new] + pub fn new(address: String) -> Self { + Self { address } + } +} + +/// Prime Protocol Validator Client - for validating task results +#[pyclass] +pub(crate) struct ValidatorClient { + message_queue: MessageQueue, + runtime: Option, +} + +#[pymethods] +impl ValidatorClient { + #[new] + #[pyo3(signature = (rpc_url, private_key=None))] + pub fn new(rpc_url: String, private_key: Option) -> PyResult { + // TODO: Implement validator initialization + let _ = rpc_url; + let _ = private_key; + + Ok(Self { + message_queue: MessageQueue::new(), + runtime: None, + }) + } + + /// Initialize the validator client and start listening for messages + pub fn start(&mut self, py: Python) -> PyResult<()> { + // Create a new runtime for this validator + let rt = 
tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Store the runtime for future use + self.runtime = Some(rt); + + Ok(()) + } + + pub fn list_nodes(&self) -> PyResult> { + // TODO: Implement validator node listing from chain that are not yet validated + Ok(vec![]) + } + + pub fn fetch_node_details(&self, node_id: String) -> PyResult> { + // TODO: Implement validator node details fetching + Ok(None) + } + + pub fn mark_node_as_validated(&self, node_id: String) -> PyResult<()> { + // TODO: Implement validator node marking as validated + Ok(()) + } + + pub fn send_request_to_node(&self, node_id: String, request: String) -> PyResult<()> { + // TODO: Implement validator node request sending + Ok(()) + } + + pub fn send_request_to_node_address( + &self, + node_address: String, + request: String, + ) -> PyResult<()> { + // TODO: Implement validator node request sending to specific address + let _ = node_address; + let _ = request; + Ok(()) + } + + /// Get the latest validation result from the internal message queue + pub fn get_latest_message(&self, py: Python) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(py.allow_threads(|| rt.block_on(self.message_queue.get_validation_result()))) + } else { + Err(PyErr::new::( + "Validator not started. 
Call start() first.".to_string(), + )) + } + } + + /// Get the number of pending validation results + pub fn get_queue_size(&self, py: Python) -> PyResult { + if let Some(rt) = self.runtime.as_ref() { + Ok(py.allow_threads(|| rt.block_on(self.message_queue.get_queue_size()))) + } else { + Ok(0) + } + } +} diff --git a/crates/prime-protocol-py/src/worker/client.rs b/crates/prime-protocol-py/src/worker/client.rs new file mode 100644 index 00000000..db30c0b4 --- /dev/null +++ b/crates/prime-protocol-py/src/worker/client.rs @@ -0,0 +1,439 @@ +use crate::error::{PrimeProtocolError, Result}; +use crate::worker::message_queue::MessageQueue; +use alloy::primitives::utils::format_ether; +use alloy::primitives::{Address, U256}; +use prime_core::operations::compute_node::ComputeNodeOperations; +use prime_core::operations::provider::ProviderOperations; +use shared::web3::contracts::core::builder::{ContractBuilder, Contracts}; +use shared::web3::contracts::structs::compute_pool::PoolStatus; +use shared::web3::wallet::{Wallet, WalletProvider}; +use std::sync::Arc; +use url::Url; + +pub struct WorkerClientCore { + rpc_url: String, + compute_pool_id: u64, + private_key_provider: Option, + private_key_node: Option, + auto_accept_transactions: bool, + funding_retry_count: u32, + message_queue: Arc, +} + +impl WorkerClientCore { + pub fn new( + compute_pool_id: u64, + rpc_url: String, + private_key_provider: Option, + private_key_node: Option, + auto_accept_transactions: Option, + funding_retry_count: Option, + ) -> Result { + if rpc_url.is_empty() { + return Err(PrimeProtocolError::InvalidConfig( + "RPC URL cannot be empty".to_string(), + )); + } + + Url::parse(&rpc_url) + .map_err(|_| PrimeProtocolError::InvalidConfig("Invalid RPC URL format".to_string()))?; + + Ok(Self { + rpc_url, + compute_pool_id, + private_key_provider, + private_key_node, + auto_accept_transactions: auto_accept_transactions.unwrap_or(true), + funding_retry_count: funding_retry_count.unwrap_or(10), + 
message_queue: Arc::new(MessageQueue::new()), + }) + } + + pub async fn start_async(&self) -> Result<()> { + let (provider_wallet, node_wallet, contracts) = + self.initialize_blockchain_components().await?; + let pool_info = self.wait_for_active_pool(&contracts).await?; + + log::debug!("Pool info: {:?}", pool_info); + log::debug!("Checking provider"); + self.ensure_provider_registered(&provider_wallet, &contracts) + .await?; + log::debug!("Checking compute node"); + self.ensure_compute_node_registered(&provider_wallet, &node_wallet, &contracts) + .await?; + + log::debug!("blockchain components initialized"); + log::debug!("starting queues"); + + // Start the message queue listener + self.message_queue.start_listener().await.map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to start message listener: {}", e)) + })?; + + log::debug!("Message queue listener started"); + + Ok(()) + } + + async fn initialize_blockchain_components( + &self, + ) -> Result<(Wallet, Wallet, Contracts)> { + let private_key_provider = self.get_private_key_provider()?; + let private_key_node = self.get_private_key_node()?; + let rpc_url = Url::parse(&self.rpc_url).unwrap(); + + let provider_wallet = Wallet::new(&private_key_provider, rpc_url.clone()).map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to create provider wallet: {}", e)) + })?; + + let node_wallet = Wallet::new(&private_key_node, rpc_url.clone()).map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to create node wallet: {}", e)) + })?; + + let contracts = ContractBuilder::new(provider_wallet.provider()) + .with_compute_pool() + .with_compute_registry() + .with_ai_token() + .with_prime_network() + .with_stake_manager() + .build() + .map_err(|e| PrimeProtocolError::BlockchainError(e.to_string()))?; + + Ok((provider_wallet, node_wallet, contracts)) + } + + async fn wait_for_active_pool( + &self, + contracts: &Contracts, + ) -> Result { + loop { + match contracts + .compute_pool + 
.get_pool_info(U256::from(self.compute_pool_id)) + .await + { + Ok(pool) if pool.status == PoolStatus::ACTIVE => return Ok(pool), + Ok(_) => { + log::info!("Pool not active yet, waiting..."); + tokio::time::sleep(tokio::time::Duration::from_secs(15)).await; + } + Err(e) => { + return Err(PrimeProtocolError::BlockchainError(format!( + "Failed to get pool info: {}", + e + ))); + } + } + } + } + + async fn ensure_provider_registered( + &self, + provider_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let provider_ops = ProviderOperations::new( + provider_wallet.clone(), + contracts.clone(), + self.auto_accept_transactions, + ); + + let provider_exists = self.check_provider_exists(&provider_ops).await?; + let is_whitelisted = self.check_provider_whitelisted(&provider_ops).await?; + + if provider_exists && is_whitelisted { + log::info!("Provider is registered and whitelisted"); + } else { + self.register_provider_if_needed(&provider_ops, contracts) + .await?; + } + + self.ensure_adequate_stake(&provider_ops, provider_wallet, contracts) + .await?; + + Ok(()) + } + + async fn check_provider_exists(&self, provider_ops: &ProviderOperations) -> Result { + provider_ops.check_provider_exists().await.map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check if provider exists: {}", + e + )) + }) + } + + async fn check_provider_whitelisted(&self, provider_ops: &ProviderOperations) -> Result { + provider_ops + .check_provider_whitelisted() + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check provider whitelist status: {}", + e + )) + }) + } + + async fn register_provider_if_needed( + &self, + provider_ops: &ProviderOperations, + contracts: &Contracts, + ) -> Result<()> { + let stake_manager = contracts.stake_manager.as_ref().ok_or_else(|| { + PrimeProtocolError::BlockchainError("Stake manager not initialized".to_string()) + })?; + let compute_units = U256::from(1); // TODO: Make configurable + + let 
required_stake = stake_manager + .calculate_stake(compute_units, U256::from(0)) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to calculate required stake: {}", + e + )) + })?; + + log::info!("Required stake: {}", format_ether(required_stake)); + + // Add timeout to prevent hanging on blockchain operations + let register_future = + provider_ops.retry_register_provider(required_stake, self.funding_retry_count, None); + + tokio::time::timeout( + tokio::time::Duration::from_secs(300), // 5 minute timeout + register_future, + ) + .await + .map_err(|_| { + PrimeProtocolError::BlockchainError( + "Provider registration timed out after 5 minutes".to_string(), + ) + })? + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to register provider: {}", e)) + })?; + + log::info!("Provider registered successfully"); + Ok(()) + } + + async fn ensure_adequate_stake( + &self, + provider_ops: &ProviderOperations, + provider_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let stake_manager = contracts.stake_manager.as_ref().ok_or_else(|| { + PrimeProtocolError::BlockchainError("Stake manager not initialized".to_string()) + })?; + let provider_address = provider_wallet.wallet.default_signer().address(); + + let provider_total_compute = self + .get_provider_total_compute(contracts, provider_address) + .await?; + let provider_stake = self.get_provider_stake(contracts, provider_address).await; + let compute_units = U256::from(1); // TODO: Make configurable + + let required_stake = stake_manager + .calculate_stake(compute_units, provider_total_compute) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to calculate required stake: {}", + e + )) + })?; + + if required_stake > provider_stake { + self.increase_provider_stake(provider_ops, required_stake, provider_stake) + .await?; + } + + Ok(()) + } + + async fn get_provider_total_compute( + &self, + contracts: &Contracts, + provider_address: 
Address, + ) -> Result { + contracts + .compute_registry + .get_provider_total_compute(provider_address) + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to get provider total compute: {}", + e + )) + }) + } + + async fn get_provider_stake( + &self, + contracts: &Contracts, + provider_address: Address, + ) -> U256 { + let stake_manager = contracts.stake_manager.as_ref(); + match stake_manager { + Some(manager) => manager + .get_stake(provider_address) + .await + .unwrap_or_default(), + None => U256::from(0), + } + } + + async fn increase_provider_stake( + &self, + provider_ops: &ProviderOperations, + required_stake: U256, + current_stake: U256, + ) -> Result<()> { + log::info!( + "Provider stake is less than required stake. Required: {} tokens, Current: {} tokens", + format_ether(required_stake), + format_ether(current_stake) + ); + + // Add timeout to prevent hanging on stake increase operations + let stake_future = provider_ops.increase_stake(required_stake - current_stake); + + tokio::time::timeout( + tokio::time::Duration::from_secs(300), // 5 minute timeout + stake_future, + ) + .await + .map_err(|_| { + PrimeProtocolError::BlockchainError( + "Stake increase timed out after 5 minutes".to_string(), + ) + })? 
+ .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to increase stake: {}", e)) + })?; + + log::info!("Successfully increased stake"); + Ok(()) + } + + async fn ensure_compute_node_registered( + &self, + provider_wallet: &Wallet, + node_wallet: &Wallet, + contracts: &Contracts, + ) -> Result<()> { + let compute_node_ops = + ComputeNodeOperations::new(provider_wallet, node_wallet, contracts.clone()); + + let compute_node_exists = self.check_compute_node_exists(&compute_node_ops).await?; + + if compute_node_exists { + log::info!("Compute node is already registered"); + return Ok(()); + } + + self.register_compute_node(&compute_node_ops).await?; + Ok(()) + } + + async fn check_compute_node_exists( + &self, + compute_node_ops: &ComputeNodeOperations<'_>, + ) -> Result { + compute_node_ops + .check_compute_node_exists() + .await + .map_err(|e| { + PrimeProtocolError::BlockchainError(format!( + "Failed to check if compute node exists: {}", + e + )) + }) + } + + async fn register_compute_node( + &self, + compute_node_ops: &ComputeNodeOperations<'_>, + ) -> Result<()> { + let compute_units = U256::from(1); // TODO: Make configurable + + // Add timeout to prevent hanging on compute node registration + let register_future = compute_node_ops.add_compute_node(compute_units); + + tokio::time::timeout( + tokio::time::Duration::from_secs(300), // 5 minute timeout + register_future, + ) + .await + .map_err(|_| { + PrimeProtocolError::BlockchainError( + "Compute node registration timed out after 5 minutes".to_string(), + ) + })? 
+ .map_err(|e| { + PrimeProtocolError::BlockchainError(format!("Failed to register compute node: {}", e)) + })?; + + log::info!("Compute node registered successfully"); + Ok(()) + } + + fn get_private_key_provider(&self) -> Result { + match &self.private_key_provider { + Some(key) => Ok(key.clone()), + None => std::env::var("PRIVATE_KEY_PROVIDER").map_err(|_| { + PrimeProtocolError::InvalidConfig("PRIVATE_KEY_PROVIDER must be set".to_string()) + }), + } + } + + fn get_private_key_node(&self) -> Result { + match &self.private_key_node { + Some(key) => Ok(key.clone()), + None => std::env::var("PRIVATE_KEY_NODE").map_err(|_| { + PrimeProtocolError::InvalidConfig("PRIVATE_KEY_NODE must be set".to_string()) + }), + } + } + + /// Get the shared message queue instance + pub fn get_message_queue(&self) -> Arc { + self.message_queue.clone() + } + + /// Stop the message queue listener + pub async fn stop_async(&self) -> Result<()> { + self.message_queue.stop_listener().await.map_err(|e| { + PrimeProtocolError::InvalidConfig(format!("Failed to stop message listener: {}", e)) + })?; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use test_log::test; + + #[test(tokio::test)] + async fn test_start_async() { + // standard anvil blockchain keys for local testing + let node_key = "0x7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6"; + let provider_key = "0x5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a"; + + // todo: currently still have to make up the local blockchain incl. 
smart contract deployments + let worker = WorkerClientCore::new( + 0, + "http://localhost:8545".to_string(), + Some(provider_key.to_string()), + Some(node_key.to_string()), + None, + None, + ) + .unwrap(); + worker.start_async().await.unwrap(); + } +} diff --git a/crates/prime-protocol-py/src/worker/message_queue.rs b/crates/prime-protocol-py/src/worker/message_queue.rs new file mode 100644 index 00000000..167fde05 --- /dev/null +++ b/crates/prime-protocol-py/src/worker/message_queue.rs @@ -0,0 +1,80 @@ +use crate::utils::message_queue::{Message, MessageQueue as GenericMessageQueue}; +use pyo3::prelude::*; + +/// Queue types for the worker message queue +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum QueueType { + PoolOwner, + Validator, +} + +/// Worker-specific message queue with predefined queue types +#[derive(Clone)] +pub struct MessageQueue { + pool_owner_queue: GenericMessageQueue, + validator_queue: GenericMessageQueue, +} + +impl MessageQueue { + /// Create a new worker message queue with pool_owner and validator queues + pub fn new() -> Self { + Self { + pool_owner_queue: GenericMessageQueue::new(None), + validator_queue: GenericMessageQueue::new(None), + } + } + + /// Start the background message listener for worker + pub(crate) async fn start_listener(&self) -> Result<(), String> { + // Start mock listeners with different frequencies + // pool_owner messages every 2 seconds, validator messages every 3 seconds + self.pool_owner_queue.start_mock_listener(2).await?; + self.validator_queue.start_mock_listener(3).await?; + Ok(()) + } + + /// Stop the background listener + pub(crate) async fn stop_listener(&self) -> Result<(), String> { + self.pool_owner_queue.stop_listener().await?; + self.validator_queue.stop_listener().await?; + Ok(()) + } + + /// Get the next message from the pool owner queue + pub(crate) async fn get_pool_owner_message(&self) -> Option { + self.pool_owner_queue.get_message().await + } + + /// Get the next message from the 
validator queue + pub(crate) async fn get_validator_message(&self) -> Option { + self.validator_queue.get_message().await + } + + /// Push a message to the appropriate queue (for testing or internal use) + pub(crate) async fn push_message( + &self, + queue_type: QueueType, + content: serde_json::Value, + ) -> Result<(), String> { + let message = Message { + content, + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + sender: Some("worker".to_string()), + }; + + match queue_type { + QueueType::PoolOwner => self.pool_owner_queue.push_message(message).await, + QueueType::Validator => self.validator_queue.push_message(message).await, + } + } + + /// Get queue sizes for monitoring + pub(crate) async fn get_queue_sizes(&self) -> (usize, usize) { + let pool_owner_size = self.pool_owner_queue.get_queue_size().await; + let validator_size = self.validator_queue.get_queue_size().await; + (pool_owner_size, validator_size) + } +} diff --git a/crates/prime-protocol-py/src/worker/mod.rs b/crates/prime-protocol-py/src/worker/mod.rs new file mode 100644 index 00000000..a308df12 --- /dev/null +++ b/crates/prime-protocol-py/src/worker/mod.rs @@ -0,0 +1,91 @@ +use pyo3::prelude::*; +mod client; +pub(crate) mod message_queue; +pub(crate) use client::WorkerClientCore; +/// Prime Protocol Worker Client - for compute nodes that execute tasks +#[pyclass] +pub(crate) struct WorkerClient { + inner: WorkerClientCore, + runtime: Option, +} + +#[pymethods] +impl WorkerClient { + #[new] + #[pyo3(signature = (compute_pool_id, rpc_url, private_key_provider=None, private_key_node=None))] + pub fn new( + compute_pool_id: u64, + rpc_url: String, + private_key_provider: Option, + private_key_node: Option, + ) -> PyResult { + let inner = WorkerClientCore::new( + compute_pool_id, + rpc_url, + private_key_provider, + private_key_node, + None, + None, + ) + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(Self { + inner, + runtime: None, + }) + 
} + + pub fn start(&mut self, py: Python) -> PyResult<()> { + // Create a new runtime for this call + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + // Run the async function with GIL released + let result = py.allow_threads(|| rt.block_on(self.inner.start_async())); + + // Store the runtime for future use + self.runtime = Some(rt); + + result.map_err(|e| PyErr::new::(e.to_string())) + } + + pub fn get_pool_owner_message(&self, py: Python) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(py.allow_threads(|| { + rt.block_on(self.inner.get_message_queue().get_pool_owner_message()) + })) + } else { + Err(PyErr::new::( + "Client not started. Call start() first.".to_string(), + )) + } + } + + pub fn get_validator_message(&self, py: Python) -> PyResult> { + if let Some(rt) = self.runtime.as_ref() { + Ok(py.allow_threads(|| { + rt.block_on(self.inner.get_message_queue().get_validator_message()) + })) + } else { + Err(PyErr::new::( + "Client not started. 
Call start() first.".to_string(), + )) + } + } + + pub fn stop(&mut self, py: Python) -> PyResult<()> { + if let Some(rt) = self.runtime.as_ref() { + py.allow_threads(|| rt.block_on(self.inner.stop_async())) + .map_err(|e| PyErr::new::(e.to_string()))?; + } + + // Clean up the runtime + if let Some(rt) = self.runtime.take() { + rt.shutdown_background(); + } + + Ok(()) + } +} diff --git a/crates/prime-protocol-py/tests/test_client.py b/crates/prime-protocol-py/tests/test_client.py new file mode 100644 index 00000000..57b02400 --- /dev/null +++ b/crates/prime-protocol-py/tests/test_client.py @@ -0,0 +1,29 @@ +"""Basic tests for the Prime Protocol Python client.""" + +import pytest +from primeprotocol import PrimeProtocolClient + + +def test_client_creation(): + """Test that client can be created with valid RPC URL.""" + client = PrimeProtocolClient("http://localhost:8545") + assert client is not None + + +def test_client_creation_with_empty_url(): + """Test that client creation fails with empty RPC URL.""" + with pytest.raises(ValueError): + PrimeProtocolClient("") + + +def test_client_creation_with_invalid_url(): + """Test that client creation fails with invalid RPC URL.""" + with pytest.raises(ValueError): + PrimeProtocolClient("not-a-valid-url") + + +def test_has_compute_pool_exists_method(): + """Test that the client has the compute_pool_exists method.""" + client = PrimeProtocolClient("http://example.com:8545") + assert hasattr(client, 'compute_pool_exists') + assert callable(getattr(client, 'compute_pool_exists')) \ No newline at end of file diff --git a/crates/prime-protocol-py/uv.lock b/crates/prime-protocol-py/uv.lock new file mode 100644 index 00000000..639a70ba --- /dev/null +++ b/crates/prime-protocol-py/uv.lock @@ -0,0 +1,7 @@ +version = 1 +requires-python = ">=3.8" + +[[package]] +name = "primeprotocol" +version = "0.1.0" +source = { editable = "." 
} diff --git a/crates/shared/Cargo.toml b/crates/shared/Cargo.toml index 9afdafff..4d3a8760 100644 --- a/crates/shared/Cargo.toml +++ b/crates/shared/Cargo.toml @@ -15,6 +15,8 @@ default = [] testnet = [] [dependencies] +p2p = { workspace = true} + tokio = { workspace = true } alloy = { workspace = true } alloy-provider = { workspace = true } @@ -40,3 +42,5 @@ iroh = { workspace = true } rand_v8 = { workspace = true } subtle = "2.6.1" utoipa = { version = "5.3.0", features = ["actix_extras", "chrono", "uuid"] } +futures = { workspace = true } +tokio-util = { workspace = true } diff --git a/crates/shared/src/models/invite.rs b/crates/shared/src/models/invite.rs deleted file mode 100644 index 08cf2a5e..00000000 --- a/crates/shared/src/models/invite.rs +++ /dev/null @@ -1,20 +0,0 @@ -use serde::Deserialize; -use serde::Serialize; - -#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] -pub struct InviteRequest { - pub invite: String, - pub pool_id: u32, - // Either master url or ip and port - pub master_url: Option, - pub master_ip: Option, - pub master_port: Option, - pub timestamp: u64, - pub expiration: [u8; 32], - pub nonce: [u8; 32], -} - -#[derive(Deserialize, Serialize)] -pub struct InviteResponse { - pub status: String, -} diff --git a/crates/shared/src/models/metric.rs b/crates/shared/src/models/metric.rs index 47b27f24..b85c4926 100644 --- a/crates/shared/src/models/metric.rs +++ b/crates/shared/src/models/metric.rs @@ -58,7 +58,7 @@ mod tests { let invalid_values = vec![(f64::INFINITY, "infinite value"), (f64::NAN, "NaN value")]; for (value, case) in invalid_values { let entry = MetricEntry::new(key.clone(), value); - assert!(entry.is_err(), "Should fail for {}", case); + assert!(entry.is_err(), "Should fail for {case}"); } } diff --git a/crates/shared/src/models/mod.rs b/crates/shared/src/models/mod.rs index 0bbe8968..dea669b3 100644 --- a/crates/shared/src/models/mod.rs +++ b/crates/shared/src/models/mod.rs @@ -1,7 +1,5 @@ pub mod api; -pub mod 
challenge; pub mod heartbeat; -pub mod invite; pub mod metric; pub mod node; pub mod storage; diff --git a/crates/shared/src/p2p/client.rs b/crates/shared/src/p2p/client.rs deleted file mode 100644 index 54e6de45..00000000 --- a/crates/shared/src/p2p/client.rs +++ /dev/null @@ -1,237 +0,0 @@ -use alloy::primitives::Address; -use anyhow::Result; -use iroh::endpoint::{RecvStream, SendStream}; -use iroh::{Endpoint, NodeAddr, NodeId, RelayMode, SecretKey}; -use log::{debug, info}; -use std::str::FromStr; -use std::time::Duration; - -use crate::p2p::messages::{P2PMessage, P2PRequest, P2PResponse}; -use crate::p2p::protocol::PRIME_P2P_PROTOCOL; -use crate::security::request_signer::sign_message; -use crate::web3::wallet::Wallet; -use rand_v8::rngs::OsRng; -use rand_v8::Rng; - -pub struct P2PClient { - endpoint: Endpoint, - node_id: NodeId, - wallet: Wallet, -} - -impl P2PClient { - pub async fn new(wallet: Wallet) -> Result { - let mut rng = rand_v8::thread_rng(); - let secret_key = SecretKey::generate(&mut rng); - let node_id = secret_key.public(); - - let endpoint = Endpoint::builder() - .secret_key(secret_key) - .alpns(vec![PRIME_P2P_PROTOCOL.to_vec()]) - .relay_mode(RelayMode::Default) - .discovery_n0() - .bind() - .await?; - - info!("P2P client initialized with node ID: {node_id}"); - - Ok(Self { - endpoint, - node_id, - wallet, - }) - } - - pub fn node_id(&self) -> NodeId { - self.node_id - } - - pub fn endpoint(&self) -> &Endpoint { - &self.endpoint - } - - /// Helper function to write a message with length prefix - async fn write_message(send: &mut SendStream, message: &T) -> Result<()> { - let message_bytes = serde_json::to_vec(message)?; - send.write_all(&(message_bytes.len() as u32).to_be_bytes()) - .await?; - send.write_all(&message_bytes).await?; - Ok(()) - } - - /// Helper function to read a message with length prefix - async fn read_message(recv: &mut RecvStream) -> Result { - let mut len_bytes = [0u8; 4]; - recv.read_exact(&mut len_bytes).await?; - let 
len = u32::from_be_bytes(len_bytes) as usize; - - let mut message_bytes = vec![0u8; len]; - recv.read_exact(&mut message_bytes).await?; - - let message: T = serde_json::from_slice(&message_bytes)?; - Ok(message) - } - - pub async fn send_request( - &self, - target_p2p_id: &str, - target_addresses: &[String], - target_wallet_address: Address, - message: P2PMessage, - timeout_secs: u64, - ) -> Result { - let timeout_duration = Duration::from_secs(timeout_secs); - - tokio::time::timeout(timeout_duration, async { - self.send_request_inner( - target_p2p_id, - target_addresses, - target_wallet_address, - message, - ) - .await - }) - .await - .map_err(|_| { - anyhow::anyhow!( - "P2P request to {} timed out after {}s", - target_p2p_id, - timeout_secs - ) - })? - } - - async fn send_request_inner( - &self, - target_p2p_id: &str, - target_addresses: &[String], - target_wallet_address: Address, - message: P2PMessage, - ) -> Result { - // Parse target node ID - let node_id = NodeId::from_str(target_p2p_id)?; - - let mut socket_addrs = Vec::new(); - for addr in target_addresses { - if let Ok(socket_addr) = addr.parse() { - socket_addrs.push(socket_addr); - } - } - - if socket_addrs.is_empty() { - return Err(anyhow::anyhow!( - "No valid addresses provided for target node" - )); - } - - // Create node address - let node_addr = NodeAddr::new(node_id).with_direct_addresses(socket_addrs); - - debug!("Connecting to P2P node: {target_p2p_id}"); - - // Connect to the target node - let connection = self.endpoint.connect(node_addr, PRIME_P2P_PROTOCOL).await?; - - let (mut send, mut recv) = connection.open_bi().await?; - - // First request an auth challenge - let challenge_bytes: [u8; 32] = OsRng.gen(); - let challenge_message: String = hex::encode(challenge_bytes); - - let request_auth_challenge = P2PRequest::new(P2PMessage::RequestAuthChallenge { - message: challenge_message.clone(), - }); - Self::write_message(&mut send, &request_auth_challenge).await?; - - // Response contains the 
auth challenge we have to solve (to show we are the right node) - let auth_challenge_response: P2PResponse = Self::read_message(&mut recv).await?; - let auth_challenge_solution: P2PRequest = match auth_challenge_response.message { - P2PMessage::AuthChallenge { - signed_message, - message, - } => { - // Parse the signature from the server - let Ok(parsed_signature) = alloy::primitives::Signature::from_str(&signed_message) - else { - return Err(anyhow::anyhow!("Failed to parse signature from server")); - }; - - // Recover address from the challenge message that the server signed - let Ok(recovered_address) = - parsed_signature.recover_address_from_msg(&challenge_message) - else { - return Err(anyhow::anyhow!( - "Failed to recover address from server signature" - )); - }; - - // Verify the recovered address matches the expected target wallet address - if recovered_address != target_wallet_address { - return Err(anyhow::anyhow!( - "Server address verification failed: expected {}, got {}", - target_wallet_address, - recovered_address - )); - } - - debug!("Auth challenge received from node: {target_p2p_id}"); - let signature = sign_message(&message, &self.wallet).await.unwrap(); - P2PRequest::new(P2PMessage::AuthSolution { - signed_message: signature, - }) - } - _ => { - return Err(anyhow::anyhow!( - "Expected auth challenge, got different message type" - )); - } - }; - Self::write_message(&mut send, &auth_challenge_solution).await?; - - // Check if we are granted or rejected - let auth_response: P2PResponse = Self::read_message(&mut recv).await?; - match auth_response.message { - P2PMessage::AuthGranted { .. } => { - debug!("Auth granted with node: {target_p2p_id}"); - } - P2PMessage::AuthRejected { .. 
} => { - debug!("Auth rejected with node: {target_p2p_id}"); - return Err(anyhow::anyhow!( - "Auth rejected with node: {}", - target_p2p_id - )); - } - _ => { - return Err(anyhow::anyhow!( - "Expected auth response, got different message type" - )); - } - } - - // Now send the actual request - let request = P2PRequest::new(message); - Self::write_message(&mut send, &request).await?; - - // Read response - let response: P2PResponse = Self::read_message(&mut recv).await?; - - tokio::time::sleep(Duration::from_millis(50)).await; - - send.finish()?; - - Ok(response.message) - } - - /// Shutdown the P2P client gracefully - pub async fn shutdown(self) -> Result<()> { - info!("Shutting down P2P client with node ID: {}", self.node_id); - self.endpoint.close().await; - Ok(()) - } -} - -impl Drop for P2PClient { - fn drop(&mut self) { - debug!("P2P client dropped for node ID: {}", self.node_id); - } -} diff --git a/crates/shared/src/p2p/messages.rs b/crates/shared/src/p2p/messages.rs deleted file mode 100644 index 1624686a..00000000 --- a/crates/shared/src/p2p/messages.rs +++ /dev/null @@ -1,101 +0,0 @@ -use crate::models::challenge::{ChallengeRequest, ChallengeResponse}; -use crate::models::invite::InviteRequest; -use serde::{Deserialize, Serialize}; -use std::time::SystemTime; - -/// Maximum message size for P2P communication (1MB) -pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; - -/// P2P message types for validator-worker communication -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(tag = "type", content = "payload")] -pub enum P2PMessage { - /// Request auth challenge from worker to validator - RequestAuthChallenge { message: String }, - - /// Auth challenge from worker to validator - AuthChallenge { - signed_message: String, - message: String, - }, - - /// Auth solution from validator to worker - AuthSolution { signed_message: String }, - - /// Auth granted from worker to validator - AuthGranted {}, - - /// Auth rejected from validator to worker 
- AuthRejected {}, - - /// Simple ping message for connectivity testing - Ping { timestamp: SystemTime, nonce: u64 }, - - /// Response to ping - Pong { timestamp: SystemTime, nonce: u64 }, - - /// Hardware challenge from validator to worker - HardwareChallenge { - challenge: ChallengeRequest, - timestamp: SystemTime, - }, - - /// Hardware challenge response from worker to validator - HardwareChallengeResponse { - response: ChallengeResponse, - timestamp: SystemTime, - }, - - /// Invite request from orchestrator to worker - Invite(InviteRequest), - - /// Response to invite - InviteResponse { - status: String, - error: Option, - }, - - /// Get task logs from worker - GetTaskLogs, - - /// Response with task logs - GetTaskLogsResponse { logs: Result, String> }, - - /// Restart task on worker - RestartTask, - - /// Response to restart task - RestartTaskResponse { result: Result<(), String> }, -} - -/// P2P request wrapper with ID for tracking -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct P2PRequest { - pub id: String, - pub message: P2PMessage, -} - -/// P2P response wrapper with request ID -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct P2PResponse { - pub request_id: String, - pub message: P2PMessage, -} - -impl P2PRequest { - pub fn new(message: P2PMessage) -> Self { - Self { - id: uuid::Uuid::new_v4().to_string(), - message, - } - } -} - -impl P2PResponse { - pub fn new(request_id: String, message: P2PMessage) -> Self { - Self { - request_id, - message, - } - } -} diff --git a/crates/shared/src/p2p/mod.rs b/crates/shared/src/p2p/mod.rs index f505f3b1..9d0e4016 100644 --- a/crates/shared/src/p2p/mod.rs +++ b/crates/shared/src/p2p/mod.rs @@ -1,6 +1,3 @@ -pub mod client; -pub mod messages; -pub mod protocol; +mod service; -pub use client::P2PClient; -pub use protocol::*; +pub use service::*; diff --git a/crates/shared/src/p2p/protocol.rs b/crates/shared/src/p2p/protocol.rs deleted file mode 100644 index 2aab189d..00000000 --- 
a/crates/shared/src/p2p/protocol.rs +++ /dev/null @@ -1,5 +0,0 @@ -/// Protocol ID for Prime P2P communication -pub const PRIME_P2P_PROTOCOL: &[u8] = b"prime-p2p-v1"; - -/// Timeout for P2P requests in seconds -pub const P2P_REQUEST_TIMEOUT: u64 = 30; diff --git a/crates/shared/src/p2p/service.rs b/crates/shared/src/p2p/service.rs new file mode 100644 index 00000000..9223bc3d --- /dev/null +++ b/crates/shared/src/p2p/service.rs @@ -0,0 +1,435 @@ +use crate::web3::wallet::Wallet; +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use p2p::{ + AuthenticationInitiationRequest, AuthenticationResponse, AuthenticationSolutionRequest, + IncomingMessage, Libp2pIncomingMessage, Node, NodeBuilder, OutgoingMessage, PeerId, Protocol, + Protocols, Response, +}; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; + +pub struct OutgoingRequest { + pub peer_wallet_address: alloy::primitives::Address, + pub request: p2p::Request, + pub peer_id: String, + pub multiaddrs: Vec, + pub response_tx: tokio::sync::oneshot::Sender, +} + +/// A p2p service implementation that is used by the validator and the orchestrator. +/// It handles the authentication protocol used before sending +/// requests to the worker. 
+pub struct Service { + node: Node, + incoming_messages_rx: Receiver, + outgoing_messages_rx: Receiver, + cancellation_token: CancellationToken, + context: Context, +} + +impl Service { + pub fn new( + keypair: p2p::Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet: Wallet, + protocols: Protocols, + ) -> Result<(Self, Sender)> { + let (node, incoming_messages_rx, outgoing_messages) = + build_p2p_node(keypair, port, cancellation_token.clone(), protocols.clone()) + .context("failed to build p2p node")?; + let (outgoing_messages_tx, outgoing_messages_rx) = tokio::sync::mpsc::channel(100); + + Ok(( + Self { + node, + incoming_messages_rx, + outgoing_messages_rx, + cancellation_token, + context: Context::new(outgoing_messages, wallet, protocols), + }, + outgoing_messages_tx, + )) + } + + pub async fn run(self) { + use futures::StreamExt as _; + + let Self { + node, + mut incoming_messages_rx, + mut outgoing_messages_rx, + cancellation_token, + context, + } = self; + tokio::task::spawn(node.run()); + + let mut incoming_message_handlers = FuturesUnordered::new(); + let mut outgoing_message_handlers = FuturesUnordered::new(); + + loop { + tokio::select! 
{ + _ = cancellation_token.cancelled() => { + break; + } + Some(message) = outgoing_messages_rx.recv() => { + let handle = tokio::task::spawn(handle_outgoing_message(message, context.clone())); + outgoing_message_handlers.push(handle); + } + Some(message) = incoming_messages_rx.recv() => { + let context = context.clone(); + let handle = tokio::task::spawn( + handle_incoming_message(message, context) + ); + incoming_message_handlers.push(handle); + } + Some(res) = incoming_message_handlers.next() => { + if let Err(e) = res { + log::error!("failed to handle incoming message: {e}"); + } + } + Some(res) = outgoing_message_handlers.next() => { + if let Err(e) = res { + log::error!("failed to handle outgoing message: {e}"); + } + } + } + } + } +} + +fn build_p2p_node( + keypair: p2p::Keypair, + port: u16, + cancellation_token: CancellationToken, + protocols: Protocols, +) -> Result<(Node, Receiver, Sender)> { + NodeBuilder::new() + .with_keypair(keypair) + .with_port(port) + .with_authentication() + .with_protocols(protocols) + .with_cancellation_token(cancellation_token) + .try_build() +} + +#[derive(Clone)] +struct Context { + // outbound message channel; receiver is held by libp2p node + outgoing_messages: Sender, + + // ongoing authentication requests + ongoing_auth_requests: Arc>>, + is_authenticated_with_peer: Arc>>, + + // this assumes that there is only one outbound request per protocol per peer at a time, + // is this a correct assumption? 
+ // response channel is for sending the response back to the caller who initiated the request + #[allow(clippy::type_complexity)] + ongoing_outbound_requests: + Arc>>>, + + wallet: Wallet, + protocols: Protocols, +} + +#[derive(Debug)] +struct OngoingAuthChallenge { + peer_wallet_address: alloy::primitives::Address, + auth_challenge_request_message: String, + outgoing_message: p2p::Request, + response_tx: tokio::sync::oneshot::Sender, +} + +impl Context { + fn new( + outgoing_messages: Sender, + wallet: Wallet, + protocols: Protocols, + ) -> Self { + Self { + outgoing_messages, + ongoing_auth_requests: Arc::new(RwLock::new(HashMap::new())), + is_authenticated_with_peer: Arc::new(RwLock::new(HashSet::new())), + ongoing_outbound_requests: Arc::new(RwLock::new(HashMap::new())), + wallet, + protocols, + } + } +} + +async fn handle_outgoing_message(message: OutgoingRequest, context: Context) -> Result<()> { + use rand_v8::rngs::OsRng; + use rand_v8::Rng as _; + use std::str::FromStr as _; + + let OutgoingRequest { + peer_wallet_address, + request, + peer_id, + multiaddrs, + response_tx, + } = message; + + let peer_id = PeerId::from_str(&peer_id).context("failed to parse peer id")?; + + // check if we're authenticated already + let is_authenticated_with_peer = context.is_authenticated_with_peer.read().await; + if is_authenticated_with_peer.contains(&peer_id) { + log::debug!( + "already authenticated with peer {peer_id}, skipping validation authentication" + ); + // multiaddresses are already known, as we've connected to them previously + context + .outgoing_messages + .send(request.into_outgoing_message(peer_id, vec![])) + .await + .context("failed to send outgoing message")?; + return Ok(()); + } + + // ensure there's no ongoing challenge + // use write-lock to make this atomic until we finish sending the auth request and writing to the map + let mut ongoing_auth_requests = context.ongoing_auth_requests.write().await; + if ongoing_auth_requests.contains_key(&peer_id) { 
+ bail!("ongoing auth request for {} already exists", peer_id); + } + + let multiaddrs = multiaddrs + .iter() + .filter_map( + |addr| p2p::Multiaddr::from_str(addr).ok(), /* ?.with_p2p(peer_id).ok()*/ + ) + .collect::>(); + if multiaddrs.is_empty() { + bail!("no valid multiaddrs for peer id {peer_id}"); + } + + // create the authentication challenge request message + let challenge_bytes: [u8; 32] = OsRng.gen(); + let auth_challenge_message: String = hex::encode(challenge_bytes); + + let req: p2p::Request = AuthenticationInitiationRequest { + message: auth_challenge_message.clone(), + } + .into(); + let outgoing_message = req.into_outgoing_message(peer_id, multiaddrs); + log::debug!("sending ValidatorAuthenticationInitiationRequest to {peer_id}"); + context + .outgoing_messages + .send(outgoing_message) + .await + .context("failed to send outgoing message")?; + + // store the ongoing auth challenge + let ongoing_challenge = OngoingAuthChallenge { + peer_wallet_address, + auth_challenge_request_message: auth_challenge_message.clone(), + outgoing_message: request, + response_tx, + }; + + ongoing_auth_requests.insert(peer_id, ongoing_challenge); + Ok(()) +} + +async fn handle_incoming_message(message: IncomingMessage, context: Context) -> Result<()> { + match message.message { + Libp2pIncomingMessage::Request { + request_id: _, + request, + channel: _, + } => { + log::error!( + "node should not receive incoming requests: {request:?} from {}", + message.peer + ); + } + Libp2pIncomingMessage::Response { + request_id: _, + response, + } => { + log::debug!("received incoming response {response:?}"); + handle_incoming_response(message.peer, response, context) + .await + .context("failed to handle incoming response")?; + } + } + Ok(()) +} + +async fn handle_incoming_response( + from: PeerId, + response: p2p::Response, + context: Context, +) -> Result<()> { + match response { + p2p::Response::Authentication(resp) => { + log::debug!("received 
ValidatorAuthenticationSolutionResponse from {from}: {resp:?}"); + handle_validation_authentication_response(from, resp, context) + .await + .context("failed to handle validator authentication response")?; + } + p2p::Response::HardwareChallenge(ref resp) => { + if !context.protocols.has_hardware_challenge() { + bail!("received HardwareChallengeResponse from {from}, but hardware challenge protocol is not enabled"); + } + + log::debug!("received HardwareChallengeResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = + ongoing_outbound_requests.remove(&(from, Protocol::HardwareChallenge)) + else { + bail!( + "no ongoing hardware challenge for peer {from}, cannot handle HardwareChallengeResponse" + ); + }; + let _ = response_tx.send(response); + } + p2p::Response::Invite(ref resp) => { + if !context.protocols.has_invite() { + bail!("received InviteResponse from {from}, but invite protocol is not enabled"); + } + + log::debug!("received InviteResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = ongoing_outbound_requests.remove(&(from, Protocol::Invite)) + else { + bail!("no ongoing invite for peer {from}, cannot handle InviteResponse"); + }; + let _ = response_tx.send(response); + } + p2p::Response::GetTaskLogs(ref resp) => { + if !context.protocols.has_get_task_logs() { + bail!("received GetTaskLogsResponse from {from}, but get task logs protocol is not enabled"); + } + + log::debug!("received GetTaskLogsResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = + ongoing_outbound_requests.remove(&(from, Protocol::GetTaskLogs)) + else { + bail!("no ongoing GetTaskLogs for peer {from}, cannot handle GetTaskLogsResponse"); + }; + let _ = response_tx.send(response); + } + p2p::Response::RestartTask(ref 
resp) => { + if !context.protocols.has_restart() { + bail!("received RestartResponse from {from}, but restart protocol is not enabled"); + } + + log::debug!("received RestartResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = ongoing_outbound_requests.remove(&(from, Protocol::Restart)) + else { + bail!("no ongoing Restart for peer {from}, cannot handle RestartResponse"); + }; + let _ = response_tx.send(response); + } + p2p::Response::General(ref resp) => { + if !context.protocols.has_general() { + bail!("received GeneralResponse from {from}, but general protocol is not enabled"); + } + + log::debug!("received GeneralResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = ongoing_outbound_requests.remove(&(from, Protocol::General)) + else { + bail!("no ongoing General for peer {from}, cannot handle GeneralResponse"); + }; + let _ = response_tx.send(response); + } + } + + Ok(()) +} + +async fn handle_validation_authentication_response( + from: PeerId, + response: p2p::AuthenticationResponse, + context: Context, +) -> Result<()> { + use crate::security::request_signer::sign_message; + use std::str::FromStr as _; + + match response { + AuthenticationResponse::Initiation(req) => { + let ongoing_auth_requests = context.ongoing_auth_requests.read().await; + let Some(ongoing_challenge) = ongoing_auth_requests.get(&from) else { + bail!( + "no ongoing hardware challenge for peer {from}, cannot handle ValidatorAuthenticationInitiationResponse" + ); + }; + + let Ok(parsed_signature) = alloy::primitives::Signature::from_str(&req.signature) + else { + bail!("failed to parse signature from response"); + }; + + // recover address from the challenge message that the peer signed + let Ok(recovered_address) = parsed_signature + 
.recover_address_from_msg(&ongoing_challenge.auth_challenge_request_message) + else { + bail!("Failed to recover address from response signature") + }; + + // verify the recovered address matches the expected worker wallet address + if recovered_address != ongoing_challenge.peer_wallet_address { + bail!( + "peer address verification failed: expected {}, got {recovered_address}", + ongoing_challenge.peer_wallet_address, + ) + } + + log::debug!("auth challenge initiation response received from node: {from}"); + let signature = sign_message(&req.message, &context.wallet).await.unwrap(); + + let req: p2p::Request = AuthenticationSolutionRequest { signature }.into(); + let req = req.into_outgoing_message(from, vec![]); + context + .outgoing_messages + .send(req) + .await + .context("failed to send outgoing message")?; + } + AuthenticationResponse::Solution(req) => { + let mut ongoing_auth_requests = context.ongoing_auth_requests.write().await; + let Some(ongoing_challenge) = ongoing_auth_requests.remove(&from) else { + bail!( + "no ongoing hardware challenge for peer {from}, cannot handle ValidatorAuthenticationSolutionResponse" + ); + }; + + match req { + p2p::AuthenticationSolutionResponse::Granted => {} + p2p::AuthenticationSolutionResponse::Rejected => { + log::debug!("auth challenge rejected by node: {from}"); + return Ok(()); + } + } + + // auth was granted, finally send the hardware challenge + let mut is_authenticated_with_peer = context.is_authenticated_with_peer.write().await; + is_authenticated_with_peer.insert(from); + + let protocol = ongoing_challenge.outgoing_message.protocol(); + let req = ongoing_challenge + .outgoing_message + .into_outgoing_message(from, vec![]); + context + .outgoing_messages + .send(req) + .await + .context("failed to send outgoing message")?; + + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + ongoing_outbound_requests.insert((from, protocol), ongoing_challenge.response_tx); + } + } + Ok(()) 
+} diff --git a/crates/shared/src/security/auth_signature_middleware.rs b/crates/shared/src/security/auth_signature_middleware.rs index 1c4c1e10..8ba7767e 100644 --- a/crates/shared/src/security/auth_signature_middleware.rs +++ b/crates/shared/src/security/auth_signature_middleware.rs @@ -634,10 +634,10 @@ mod tests { .await; log::info!("Address: {}", wallet.wallet.default_signer().address()); - log::info!("Signature: {}", signature); - log::info!("Nonce: {}", nonce); + log::info!("Signature: {signature}"); + log::info!("Nonce: {nonce}"); let req = test::TestRequest::get() - .uri(&format!("/test?nonce={}", nonce)) + .uri(&format!("/test?nonce={nonce}")) .insert_header(( "x-address", wallet.wallet.default_signer().address().to_string(), @@ -801,8 +801,7 @@ mod tests { // Create multiple addresses let addresses: Vec
= (0..5) .map(|i| { - Address::from_str(&format!("0x{}000000000000000000000000000000000000000", i)) - .unwrap() + Address::from_str(&format!("0x{i}000000000000000000000000000000000000000")).unwrap() }) .collect(); diff --git a/crates/shared/src/security/request_signer.rs b/crates/shared/src/security/request_signer.rs index ff3e9964..c5ea3605 100644 --- a/crates/shared/src/security/request_signer.rs +++ b/crates/shared/src/security/request_signer.rs @@ -143,7 +143,7 @@ mod tests { let signature = sign_request(endpoint, &wallet, Some(&empty_data)) .await .unwrap(); - println!("Signature: {}", signature); + println!("Signature: {signature}"); assert!(signature.starts_with("0x")); assert_eq!(signature.len(), 132); } diff --git a/crates/shared/src/utils/google_cloud.rs b/crates/shared/src/utils/google_cloud.rs index 128259eb..72fae856 100644 --- a/crates/shared/src/utils/google_cloud.rs +++ b/crates/shared/src/utils/google_cloud.rs @@ -194,20 +194,14 @@ mod tests { #[tokio::test] async fn test_generate_mapping_file() { // Check if required environment variables are set - let bucket_name = match std::env::var("S3_BUCKET_NAME") { - Ok(name) => name, - Err(_) => { - println!("Skipping test: BUCKET_NAME not set"); - return; - } + let Ok(bucket_name) = std::env::var("S3_BUCKET_NAME") else { + println!("Skipping test: BUCKET_NAME not set"); + return; }; - let credentials_base64 = match std::env::var("S3_CREDENTIALS") { - Ok(credentials) => credentials, - Err(_) => { - println!("Skipping test: S3_CREDENTIALS not set"); - return; - } + let Ok(credentials_base64) = std::env::var("S3_CREDENTIALS") else { + println!("Skipping test: S3_CREDENTIALS not set"); + return; }; let storage = GcsStorageProvider::new(&bucket_name, &credentials_base64) @@ -219,15 +213,15 @@ mod tests { .generate_mapping_file(&random_sha256, "run_1/file.parquet") .await .unwrap(); - println!("mapping_content: {}", mapping_content); - println!("bucket_name: {}", bucket_name); + println!("mapping_content: 
{mapping_content}"); + println!("bucket_name: {bucket_name}"); let original_file_name = storage .resolve_mapping_for_sha(&random_sha256) .await .unwrap(); - println!("original_file_name: {}", original_file_name); + println!("original_file_name: {original_file_name}"); assert_eq!(original_file_name, "run_1/file.parquet"); } } diff --git a/crates/shared/src/utils/mod.rs b/crates/shared/src/utils/mod.rs index d4e3f1c9..290f1ae5 100644 --- a/crates/shared/src/utils/mod.rs +++ b/crates/shared/src/utils/mod.rs @@ -119,7 +119,7 @@ mod tests { provider.add_mapping_file("sha256", "file.txt").await; provider.add_file("file.txt", "content").await; let map_file_link = provider.resolve_mapping_for_sha("sha256").await.unwrap(); - println!("map_file_link: {}", map_file_link); + println!("map_file_link: {map_file_link}"); assert_eq!(map_file_link, "file.txt"); assert_eq!( diff --git a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs index ff0a20ce..b52f96e2 100644 --- a/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs +++ b/crates/shared/src/web3/contracts/implementations/compute_pool_contract.rs @@ -29,6 +29,7 @@ impl ComputePool

{ .function("getComputePool", &[pool_id.into()])? .call() .await?; + let pool_info_tuple: &[DynSolValue] = pool_info_response.first().unwrap().as_tuple().unwrap(); @@ -60,6 +61,9 @@ impl ComputePool

{ _ => panic!("Unknown status value: {status}"), }; + println!("Mapped status: {mapped_status:?}"); + println!("Returning pool info"); + let pool_info = PoolInfo { pool_id, domain_id, diff --git a/crates/validator/Cargo.toml b/crates/validator/Cargo.toml index db3694ca..4d329921 100644 --- a/crates/validator/Cargo.toml +++ b/crates/validator/Cargo.toml @@ -7,6 +7,9 @@ edition.workspace = true workspace = true [dependencies] +shared = { workspace = true } +p2p = { workspace = true} + actix-web = { workspace = true } alloy = { workspace = true } anyhow = { workspace = true } @@ -16,25 +19,21 @@ directories = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } hex = { workspace = true } -iroh = { workspace = true } -rand_v8 = { workspace = true } -lazy_static = "1.5.0" log = { workspace = true } -nalgebra = { workspace = true } -prometheus = "0.14.0" -rand = "0.9.0" redis = { workspace = true, features = ["tokio-comp"] } -redis-test = { workspace = true } -regex = "1.11.1" reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -shared = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } -toml = { workspace = true } url = { workspace = true } +lazy_static = "1.5.0" +prometheus = "0.14.0" +rand = "0.9.0" +regex = "1.11.1" + [dev-dependencies] mockito = { workspace = true } +redis-test = { workspace = true } tempfile = "=3.14.0" diff --git a/crates/validator/src/lib.rs b/crates/validator/src/lib.rs index 760af2d1..9fac5ce8 100644 --- a/crates/validator/src/lib.rs +++ b/crates/validator/src/lib.rs @@ -5,7 +5,7 @@ mod validators; pub use metrics::export_metrics; pub use metrics::MetricsContext; -pub use p2p::P2PClient; +pub use p2p::Service as P2PService; pub use store::redis::RedisStore; pub use validators::hardware::HardwareValidator; pub use validators::synthetic_data::types::InvalidationType; diff --git a/crates/validator/src/main.rs b/crates/validator/src/main.rs 
index 55b3900d..f3b80d4b 100644 --- a/crates/validator/src/main.rs +++ b/crates/validator/src/main.rs @@ -23,7 +23,7 @@ use tokio_util::sync::CancellationToken; use url::Url; use validator::{ - export_metrics, HardwareValidator, InvalidationType, MetricsContext, P2PClient, RedisStore, + export_metrics, HardwareValidator, InvalidationType, MetricsContext, P2PService, RedisStore, SyntheticDataValidator, }; @@ -196,6 +196,10 @@ struct Args { /// Redis URL #[arg(long, default_value = "redis://localhost:6380")] redis_url: String, + + /// Libp2p port + #[arg(long, default_value = "4003")] + libp2p_port: u16, } #[tokio::main] @@ -269,19 +273,27 @@ async fn main() -> anyhow::Result<()> { MetricsContext::new(validator_wallet.address().to_string(), args.pool_id.clone()); // Initialize P2P client if enabled - let p2p_client = { - match P2PClient::new(validator_wallet.clone()).await { - Ok(client) => { - info!("P2P client initialized for testing"); - Some(client) + let keypair = p2p::Keypair::generate_ed25519(); + let (p2p_service, hardware_challenge_tx) = { + match P2PService::new( + keypair, + args.libp2p_port, + cancellation_token.clone(), + validator_wallet.clone(), + ) { + Ok(res) => { + info!("p2p service initialized successfully"); + res } Err(e) => { - error!("Failed to initialize P2P client: {e}"); - None + error!("failed to initialize p2p service: {e}"); + std::process::exit(1); } } }; + tokio::task::spawn(p2p_service.run()); + if let Some(pool_id) = args.pool_id.clone() { let pool = match contracts .compute_pool @@ -308,8 +320,7 @@ async fn main() -> anyhow::Result<()> { let contracts = contract_builder.build().unwrap(); - let hardware_validator = - HardwareValidator::new(&validator_wallet, contracts.clone(), p2p_client.as_ref()); + let hardware_validator = HardwareValidator::new(contracts.clone(), hardware_challenge_tx); let synthetic_validator = if let Some(pool_id) = args.pool_id.clone() { let penalty = U256::from(args.validator_penalty) * Unit::ETHER.wei(); @@ 
-628,7 +639,7 @@ mod tests { web::{self, post}, HttpResponse, Scope, }; - use shared::models::challenge::{calc_matrix, ChallengeRequest, ChallengeResponse, FixedF64}; + use p2p::{calc_matrix, ChallengeRequest, ChallengeResponse, FixedF64}; async fn handle_challenge(challenge: web::Json) -> HttpResponse { let result = calc_matrix(&challenge); diff --git a/crates/validator/src/p2p/client.rs b/crates/validator/src/p2p/client.rs deleted file mode 100644 index a0b21db1..00000000 --- a/crates/validator/src/p2p/client.rs +++ /dev/null @@ -1,89 +0,0 @@ -use alloy::primitives::Address; -use anyhow::Result; -use log::info; -use rand_v8::Rng; -use shared::models::challenge::{ChallengeRequest, ChallengeResponse}; -use shared::p2p::{client::P2PClient as SharedP2PClient, messages::P2PMessage}; -use shared::web3::wallet::Wallet; -use std::time::SystemTime; - -pub struct P2PClient { - shared_client: SharedP2PClient, -} - -impl P2PClient { - pub async fn new(wallet: Wallet) -> Result { - let shared_client = SharedP2PClient::new(wallet).await?; - Ok(Self { shared_client }) - } - - pub async fn ping_worker( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - ) -> Result { - let nonce = rand_v8::thread_rng().gen::(); - - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::Ping { - timestamp: SystemTime::now(), - nonce, - }, - 10, - ) - .await?; - - match response { - P2PMessage::Pong { - nonce: returned_nonce, - .. 
- } => { - if returned_nonce == nonce { - info!("Received valid pong from worker {worker_p2p_id} with nonce: {nonce}"); - Ok(nonce) - } else { - Err(anyhow::anyhow!("Invalid nonce in pong response")) - } - } - _ => Err(anyhow::anyhow!("Unexpected response type for ping")), - } - } - - pub async fn send_hardware_challenge( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - challenge: ChallengeRequest, - ) -> Result { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::HardwareChallenge { - challenge, - timestamp: SystemTime::now(), - }, - 30, - ) - .await?; - - match response { - P2PMessage::HardwareChallengeResponse { response, .. } => { - info!("Received hardware challenge response from worker {worker_p2p_id}"); - Ok(response) - } - _ => Err(anyhow::anyhow!( - "Unexpected response type for hardware challenge" - )), - } - } -} diff --git a/crates/validator/src/p2p/mod.rs b/crates/validator/src/p2p/mod.rs index 33dad50c..6fa8fac7 100644 --- a/crates/validator/src/p2p/mod.rs +++ b/crates/validator/src/p2p/mod.rs @@ -1,3 +1,103 @@ -pub(crate) mod client; +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use p2p::{Keypair, Protocols}; +use shared::p2p::OutgoingRequest; +use shared::p2p::Service as P2PService; +use shared::web3::wallet::Wallet; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio_util::sync::CancellationToken; -pub use client::P2PClient; +pub struct Service { + inner: P2PService, + + // converts incoming hardware challenges to outgoing requests + outgoing_message_tx: Sender, + hardware_challenge_rx: Receiver, +} + +impl Service { + pub fn new( + keypair: Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet: Wallet, + ) -> Result<(Self, Sender)> { + let (hardware_challenge_tx, hardware_challenge_rx) = tokio::sync::mpsc::channel(100); + let (inner, outgoing_message_tx) = 
P2PService::new( + keypair, + port, + cancellation_token.clone(), + wallet, + Protocols::new() + .with_hardware_challenge() + .with_authentication(), + ) + .context("failed to create p2p service")?; + Ok(( + Self { + inner, + outgoing_message_tx, + hardware_challenge_rx, + }, + hardware_challenge_tx, + )) + } + + pub async fn run(self) -> Result<()> { + use futures::StreamExt as _; + + let Self { + inner, + outgoing_message_tx, + mut hardware_challenge_rx, + } = self; + + tokio::task::spawn(inner.run()); + + let mut futures = FuturesUnordered::new(); + + loop { + tokio::select! { + Some(request) = hardware_challenge_rx.recv() => { + println!("p2p: got hardware challenge"); + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let resp = match incoming_resp_rx.await.context("outgoing request tx channel was dropped")? { + p2p::Response::HardwareChallenge(resp) => resp.response, + _ => bail!("unexpected response type for hardware challenge request"), + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }; + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: p2p::HardwareChallengeRequest { + challenge: request.challenge, + timestamp: std::time::SystemTime::now(), + }.into(), + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing hardware challenge request")?; + } + Some(res) = futures.next() => { + if let Err(e) = res { + log::error!("failed to handle response conversion: {e}"); + } + } + } + } + } +} + +pub struct HardwareChallengeRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) challenge: p2p::ChallengeRequest, + pub(crate) 
response_tx: tokio::sync::oneshot::Sender, +} diff --git a/crates/validator/src/store/redis.rs b/crates/validator/src/store/redis.rs index 508815c2..c0a0c36b 100644 --- a/crates/validator/src/store/redis.rs +++ b/crates/validator/src/store/redis.rs @@ -45,8 +45,8 @@ impl RedisStore { _ => panic!("Expected TCP connection"), }; - let redis_url = format!("redis://{}:{}", host, port); - debug!("Starting test Redis server at {}", redis_url); + let redis_url = format!("redis://{host}:{port}"); + debug!("Starting test Redis server at {redis_url}"); // Add a small delay to ensure server is ready thread::sleep(Duration::from_millis(100)); diff --git a/crates/validator/src/validators/hardware.rs b/crates/validator/src/validators/hardware.rs index 00736d34..da5307e3 100644 --- a/crates/validator/src/validators/hardware.rs +++ b/crates/validator/src/validators/hardware.rs @@ -1,15 +1,13 @@ use alloy::primitives::Address; +use anyhow::bail; use anyhow::Result; use log::{debug, error, info}; use shared::{ models::node::DiscoveryNode, - web3::{ - contracts::core::builder::Contracts, - wallet::{Wallet, WalletProvider}, - }, + web3::{contracts::core::builder::Contracts, wallet::WalletProvider}, }; -use crate::p2p::client::P2PClient; +use crate::p2p::HardwareChallengeRequest; use crate::validators::hardware_challenge::HardwareChallenge; /// Hardware validator implementation @@ -17,35 +15,27 @@ use crate::validators::hardware_challenge::HardwareChallenge; /// NOTE: This is a temporary implementation that will be replaced with a proper /// hardware validator in the near future. The current implementation only performs /// basic matrix multiplication challenges and does not verify actual hardware specs. 
-pub struct HardwareValidator<'a> { - wallet: &'a Wallet, +pub struct HardwareValidator { contracts: Contracts, - p2p_client: Option<&'a P2PClient>, + challenge_tx: tokio::sync::mpsc::Sender, } -impl<'a> HardwareValidator<'a> { +impl HardwareValidator { pub fn new( - wallet: &'a Wallet, contracts: Contracts, - p2p_client: Option<&'a P2PClient>, + challenge_tx: tokio::sync::mpsc::Sender, ) -> Self { Self { - wallet, contracts, - p2p_client, + challenge_tx, } } - async fn validate_node( - _wallet: &'a Wallet, - contracts: Contracts, - p2p_client: Option<&'a P2PClient>, - node: DiscoveryNode, - ) -> Result<()> { + async fn validate_node(&self, node: DiscoveryNode) -> Result<()> { let node_address = match node.id.trim_start_matches("0x").parse::

() { Ok(addr) => addr, Err(e) => { - return Err(anyhow::anyhow!("Failed to parse node address: {}", e)); + bail!("failed to parse node address: {e:?}"); } }; @@ -56,30 +46,22 @@ impl<'a> HardwareValidator<'a> { { Ok(addr) => addr, Err(e) => { - return Err(anyhow::anyhow!("Failed to parse provider address: {}", e)); + bail!("failed to parse provider address: {e:?}"); } }; // Perform hardware challenge - if let Some(p2p_client) = p2p_client { - let hardware_challenge = HardwareChallenge::new(p2p_client); - let challenge_result = hardware_challenge.challenge_node(&node).await; - - if let Err(e) = challenge_result { - println!("Challenge failed for node: {}, error: {}", node.id, e); - error!("Challenge failed for node: {}, error: {}", node.id, e); - return Err(anyhow::anyhow!("Failed to challenge node: {}", e)); - } - } else { - debug!( - "P2P client not available, skipping hardware challenge for node {}", - node.id - ); + let hardware_challenge = HardwareChallenge::new(self.challenge_tx.clone()); + let challenge_result = hardware_challenge.challenge_node(&node).await; + + if let Err(e) = challenge_result { + bail!("failed to challenge node: {e:?}"); } debug!("Sending validation transaction for node {}", node.id); - if let Err(e) = contracts + if let Err(e) = self + .contracts .prime_network .validate_node(provider_address, node_address) .await @@ -100,17 +82,11 @@ impl<'a> HardwareValidator<'a> { debug!("Non validated nodes: {non_validated:?}"); info!("Starting validation for {} nodes", non_validated.len()); - let contracts = self.contracts.clone(); - let wallet = self.wallet; - let p2p_client = self.p2p_client; - // Process non validated nodes sequentially as simple fix // to avoid nonce conflicts for now. 
Will sophisticate this in the future for node in non_validated { let node_id = node.id.clone(); - match HardwareValidator::validate_node(wallet, contracts.clone(), p2p_client, node) - .await - { + match self.validate_node(node).await { Ok(_) => (), Err(e) => { error!("Failed to validate node {node_id}: {e}"); @@ -134,7 +110,6 @@ mod tests { async fn test_challenge_node() { let coordinator_key = "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97"; let rpc_url: Url = Url::parse("http://localhost:8545").unwrap(); - let coordinator_wallet = Arc::new(Wallet::new(coordinator_key, rpc_url).unwrap()); let contracts = ContractBuilder::new(coordinator_wallet.provider()) @@ -145,7 +120,8 @@ mod tests { .build() .unwrap(); - let validator = HardwareValidator::new(&coordinator_wallet, contracts, None); + let (tx, _rx) = tokio::sync::mpsc::channel(100); + let validator = HardwareValidator::new(contracts, tx); let fake_discovery_node1 = DiscoveryNode { is_validated: false, @@ -185,7 +161,7 @@ mod tests { let result = validator.validate_nodes(nodes).await; let elapsed = start_time.elapsed(); assert!(elapsed < std::time::Duration::from_secs(11)); - println!("Validation took: {:?}", elapsed); + println!("Validation took: {elapsed:?}"); assert!(result.is_ok()); } diff --git a/crates/validator/src/validators/hardware_challenge.rs b/crates/validator/src/validators/hardware_challenge.rs index c881c542..6970355d 100644 --- a/crates/validator/src/validators/hardware_challenge.rs +++ b/crates/validator/src/validators/hardware_challenge.rs @@ -1,40 +1,38 @@ -use crate::p2p::client::P2PClient; use alloy::primitives::Address; -use anyhow::{Error, Result}; +use anyhow::{bail, Context as _, Result}; use log::{error, info}; use rand::{rng, Rng}; -use shared::models::{ - challenge::{calc_matrix, ChallengeRequest, FixedF64}, - node::DiscoveryNode, -}; +use shared::models::node::DiscoveryNode; use std::str::FromStr; -pub(crate) struct HardwareChallenge<'a> { - p2p_client: &'a 
P2PClient, +use crate::p2p::HardwareChallengeRequest; + +pub(crate) struct HardwareChallenge { + challenge_tx: tokio::sync::mpsc::Sender, } -impl<'a> HardwareChallenge<'a> { - pub(crate) fn new(p2p_client: &'a P2PClient) -> Self { - Self { p2p_client } +impl HardwareChallenge { + pub(crate) fn new(challenge_tx: tokio::sync::mpsc::Sender) -> Self { + Self { challenge_tx } } - pub(crate) async fn challenge_node(&self, node: &DiscoveryNode) -> Result { + pub(crate) async fn challenge_node(&self, node: &DiscoveryNode) -> Result<()> { // Check if node has P2P ID and addresses let p2p_id = node .node .worker_p2p_id - .as_ref() + .clone() .ok_or_else(|| anyhow::anyhow!("Node {} does not have P2P ID", node.id))?; let p2p_addresses = node .node .worker_p2p_addresses - .as_ref() + .clone() .ok_or_else(|| anyhow::anyhow!("Node {} does not have P2P addresses", node.id))?; // create random challenge matrix let challenge_matrix = self.random_challenge(3, 3, 3, 3); - let challenge_expected = calc_matrix(&challenge_matrix); + let challenge_expected = p2p::calc_matrix(&challenge_matrix); // Add timestamp to the challenge let current_time = std::time::SystemTime::now() @@ -47,34 +45,36 @@ impl<'a> HardwareChallenge<'a> { let node_address = Address::from_str(&node.node.id) .map_err(|e| anyhow::anyhow!("Failed to parse node address {}: {}", node.node.id, e))?; + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let hardware_challenge = HardwareChallengeRequest { + worker_wallet_address: node_address, + worker_p2p_id: p2p_id, + worker_addresses: p2p_addresses, + challenge: challenge_with_timestamp, + response_tx, + }; + // Send challenge via P2P - match self - .p2p_client - .send_hardware_challenge( - node_address, - p2p_id, - p2p_addresses, - challenge_with_timestamp, - ) + self.challenge_tx + .send(hardware_challenge) + .await + .context("failed to send hardware challenge request to p2p service")?; + + let resp = response_rx .await - { - Ok(response) => { - if 
challenge_expected.result == response.result { - info!("Challenge for node {} successful", node.id); - Ok(0) - } else { - error!( - "Challenge failed for node {}: expected {:?}, got {:?}", - node.id, challenge_expected.result, response.result - ); - Err(anyhow::anyhow!("Node failed challenge")) - } - } - Err(e) => { - error!("Failed to send challenge to node {}: {}", node.id, e); - Err(anyhow::anyhow!("Failed to send challenge: {}", e)) - } + .context("failed to receive response from node")?; + + if challenge_expected.result == resp.result { + info!("Challenge for node {} successful", node.id); + } else { + error!( + "Challenge failed for node {}: expected {:?}, got {:?}", + node.id, challenge_expected.result, resp.result + ); + bail!("Node failed challenge"); } + + Ok(()) } fn random_challenge( @@ -83,7 +83,9 @@ impl<'a> HardwareChallenge<'a> { cols_a: usize, rows_b: usize, cols_b: usize, - ) -> ChallengeRequest { + ) -> p2p::ChallengeRequest { + use p2p::FixedF64; + let mut rng = rng(); let data_a_vec: Vec = (0..(rows_a * cols_a)) @@ -98,7 +100,7 @@ impl<'a> HardwareChallenge<'a> { let data_a: Vec = data_a_vec.iter().map(|x| FixedF64(*x)).collect(); let data_b: Vec = data_b_vec.iter().map(|x| FixedF64(*x)).collect(); - ChallengeRequest { + p2p::ChallengeRequest { rows_a, cols_a, data_a, diff --git a/crates/validator/src/validators/synthetic_data/chain_operations.rs b/crates/validator/src/validators/synthetic_data/chain_operations.rs index 004c7e45..a0687d18 100644 --- a/crates/validator/src/validators/synthetic_data/chain_operations.rs +++ b/crates/validator/src/validators/synthetic_data/chain_operations.rs @@ -3,7 +3,7 @@ use super::*; impl SyntheticDataValidator { #[cfg(test)] pub fn soft_invalidate_work(&self, work_key: &str) -> Result<(), Error> { - info!("Soft invalidating work: {}", work_key); + info!("Soft invalidating work: {work_key}"); if self.disable_chain_invalidation { info!("Chain invalidation is disabled, skipping work soft invalidation"); @@ -54,7 
+54,7 @@ impl SyntheticDataValidator { #[cfg(test)] pub fn invalidate_work(&self, work_key: &str) -> Result<(), Error> { - info!("Invalidating work: {}", work_key); + info!("Invalidating work: {work_key}"); if let Some(metrics) = &self.metrics { metrics.record_work_key_invalidation(); @@ -98,20 +98,27 @@ impl SyntheticDataValidator { } } } - + #[cfg(test)] + #[allow(clippy::unused_async)] pub async fn invalidate_according_to_invalidation_type( &self, work_key: &str, invalidation_type: InvalidationType, ) -> Result<(), Error> { match invalidation_type { - #[cfg(test)] InvalidationType::Soft => self.soft_invalidate_work(work_key), - #[cfg(not(test))] - InvalidationType::Soft => self.soft_invalidate_work(work_key).await, - #[cfg(test)] InvalidationType::Hard => self.invalidate_work(work_key), - #[cfg(not(test))] + } + } + + #[cfg(not(test))] + pub async fn invalidate_according_to_invalidation_type( + &self, + work_key: &str, + invalidation_type: InvalidationType, + ) -> Result<(), Error> { + match invalidation_type { + InvalidationType::Soft => self.soft_invalidate_work(work_key).await, InvalidationType::Hard => self.invalidate_work(work_key).await, } } diff --git a/crates/validator/src/validators/synthetic_data/mod.rs b/crates/validator/src/validators/synthetic_data/mod.rs index ce472c8b..bf8ce6e2 100644 --- a/crates/validator/src/validators/synthetic_data/mod.rs +++ b/crates/validator/src/validators/synthetic_data/mod.rs @@ -237,7 +237,7 @@ impl SyntheticDataValidator { let score: Option = con .zscore("incomplete_groups", group_key) .await - .map_err(|e| Error::msg(format!("Failed to check incomplete tracking: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to check incomplete tracking: {e}")))?; Ok(score.is_some()) } @@ -270,13 +270,10 @@ impl SyntheticDataValidator { let _: () = con .zadd("incomplete_groups", group_key, new_deadline) .await - .map_err(|e| { - Error::msg(format!("Failed to update incomplete group deadline: {}", e)) - })?; + .map_err(|e| 
Error::msg(format!("Failed to update incomplete group deadline: {e}")))?; debug!( - "Updated deadline for incomplete group {} to {} ({} minutes from now)", - group_key, new_deadline, minutes_from_now + "Updated deadline for incomplete group {group_key} to {new_deadline} ({minutes_from_now} minutes from now)" ); Ok(()) @@ -420,7 +417,7 @@ impl SyntheticDataValidator { let data: Option = con .get(key) .await - .map_err(|e| Error::msg(format!("Failed to get work validation status: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to get work validation status: {e}")))?; match data { Some(data) => { @@ -435,8 +432,7 @@ impl SyntheticDataValidator { reason: None, })), Err(e) => Err(Error::msg(format!( - "Failed to parse work validation data: {}", - e + "Failed to parse work validation data: {e}" ))), } } @@ -1576,8 +1572,7 @@ impl SyntheticDataValidator { .await { error!( - "Failed to update work validation status for {}: {}", - work_key, e + "Failed to update work validation status for {work_key}: {e}" ); } } diff --git a/crates/validator/src/validators/synthetic_data/tests/mod.rs b/crates/validator/src/validators/synthetic_data/tests/mod.rs index a589076f..48aaee85 100644 --- a/crates/validator/src/validators/synthetic_data/tests/mod.rs +++ b/crates/validator/src/validators/synthetic_data/tests/mod.rs @@ -34,7 +34,7 @@ fn setup_test_env() -> Result<(RedisStore, Contracts), Error> { "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", url, ) - .map_err(|e| Error::msg(format!("Failed to create demo wallet: {}", e)))?; + .map_err(|e| Error::msg(format!("Failed to create demo wallet: {e}")))?; let contracts = ContractBuilder::new(demo_wallet.provider()) .with_compute_registry() @@ -45,7 +45,7 @@ fn setup_test_env() -> Result<(RedisStore, Contracts), Error> { .with_stake_manager() .with_synthetic_data_validator(Some(Address::ZERO)) .build() - .map_err(|e| Error::msg(format!("Failed to build contracts: {}", e)))?; + .map_err(|e| 
Error::msg(format!("Failed to build contracts: {e}")))?; Ok((store, contracts)) } @@ -197,8 +197,8 @@ async fn test_status_update() -> Result<(), Error> { ) .await .map_err(|e| { - error!("Failed to update work validation status: {}", e); - Error::msg(format!("Failed to update work validation status: {}", e)) + error!("Failed to update work validation status: {e}"); + Error::msg(format!("Failed to update work validation status: {e}")) })?; tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; @@ -206,8 +206,8 @@ async fn test_status_update() -> Result<(), Error> { .get_work_validation_status_from_redis("0x0000000000000000000000000000000000000000") .await .map_err(|e| { - error!("Failed to get work validation status: {}", e); - Error::msg(format!("Failed to get work validation status: {}", e)) + error!("Failed to get work validation status: {e}"); + Error::msg(format!("Failed to get work validation status: {e}")) })?; assert_eq!(status, Some(ValidationResult::Accept)); Ok(()) @@ -344,20 +344,20 @@ async fn test_group_e2e_accept() -> Result<(), Error> { let mock_storage = MockStorageProvider::new(); mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-1-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-1-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-1-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-1-0-0.parquet"), ) .await; server .mock( "POST", - format!("/validategroup/dataset/samplingn-{}-1-0.parquet", GROUP_ID).as_str(), + format!("/validategroup/dataset/samplingn-{GROUP_ID}-1-0.parquet").as_str(), ) .match_body(mockito::Matcher::Json(serde_json::json!({ "file_shas": [FILE_SHA], @@ -371,7 +371,7 @@ async fn test_group_e2e_accept() -> Result<(), Error> { server .mock( "GET", - format!("/statusgroup/dataset/samplingn-{}-1-0.parquet", GROUP_ID).as_str(), + 
format!("/statusgroup/dataset/samplingn-{GROUP_ID}-1-0.parquet").as_str(), ) .with_status(200) .with_body(r#"{"status": "accept", "input_flops": 1, "output_flops": 1000}"#) @@ -463,7 +463,7 @@ async fn test_group_e2e_accept() -> Result<(), Error> { metrics_2.contains("validator_work_keys_to_process{pool_id=\"0\",validator_id=\"0\"} 0") ); assert!(metrics_2.contains("toploc_config_name=\"Qwen/Qwen0.6\"")); - assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{}\",pool_id=\"0\",result=\"match\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"match\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -490,32 +490,32 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { let mock_storage = MockStorageProvider::new(); mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_file( - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-1.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-1.parquet"), "file2", ) .await; mock_storage .add_mapping_file( HONEST_FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-0.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-0.parquet"), ) .await; mock_storage .add_mapping_file( EXCESSIVE_FILE_SHA, - &format!("Qwen/Qwen0.6/dataset/samplingn-{}-2-0-1.parquet", GROUP_ID), + &format!("Qwen/Qwen0.6/dataset/samplingn-{GROUP_ID}-2-0-1.parquet"), ) .await; server .mock( "POST", - format!("/validategroup/dataset/samplingn-{}-2-0.parquet", GROUP_ID).as_str(), + format!("/validategroup/dataset/samplingn-{GROUP_ID}-2-0.parquet").as_str(), ) .match_body(mockito::Matcher::Json(serde_json::json!({ "file_shas": 
[HONEST_FILE_SHA, EXCESSIVE_FILE_SHA], @@ -529,7 +529,7 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { server .mock( "GET", - format!("/statusgroup/dataset/samplingn-{}-2-0.parquet", GROUP_ID).as_str(), + format!("/statusgroup/dataset/samplingn-{GROUP_ID}-2-0.parquet").as_str(), ) .with_status(200) .with_body(r#"{"status": "accept", "input_flops": 1, "output_flops": 2000}"#) @@ -636,12 +636,12 @@ async fn test_group_e2e_work_unit_mismatch() -> Result<(), Error> { assert_eq!(plan_3.group_trigger_tasks.len(), 0); assert_eq!(plan_3.group_status_check_tasks.len(), 0); let metrics_2 = export_metrics().unwrap(); - assert!(metrics_2.contains(&format!("validator_group_validations_total{{group_id=\"{}\",pool_id=\"0\",result=\"accept\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_validations_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"accept\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); assert!( metrics_2.contains("validator_work_keys_to_process{pool_id=\"0\",validator_id=\"0\"} 0") ); assert!(metrics_2.contains("toploc_config_name=\"Qwen/Qwen0.6\"")); - assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{}\",pool_id=\"0\",result=\"mismatch\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics_2.contains(&format!("validator_group_work_units_check_total{{group_id=\"{GROUP_ID}\",pool_id=\"0\",result=\"mismatch\",toploc_config_name=\"Qwen/Qwen0.6\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -734,26 +734,26 @@ async fn test_incomplete_group_recovery() -> Result<(), Error> { mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-1.parquet", GROUP_ID), + 
&format!("TestModel/dataset/test-{GROUP_ID}-2-0-1.parquet"), "file2", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), ) .await; mock_storage .add_mapping_file( FILE_SHA_2, - &format!("TestModel/dataset/test-{}-2-0-1.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-1.parquet"), ) .await; @@ -800,7 +800,7 @@ async fn test_incomplete_group_recovery() -> Result<(), Error> { assert!(group.is_none(), "Group should be incomplete"); // Check that the incomplete group is being tracked - let group_key = format!("group:{}:2:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:2:0"); let is_tracked = validator .is_group_being_tracked_as_incomplete(&group_key) .await?; @@ -847,14 +847,14 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> mock_storage .add_file( - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-2-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-2-0-0.parquet"), ) .await; @@ -902,7 +902,7 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> // Manually expire the incomplete group tracking by removing it and simulating expiry // In a real test, you would wait for the actual expiry, but for testing we simulate it - let group_key = format!("group:{}:2:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:2:0"); validator.track_incomplete_group(&group_key).await?; // Process groups past grace period (this would normally find groups past deadline) @@ -936,7 +936,7 @@ async fn test_expired_incomplete_group_soft_invalidation() -> Result<(), Error> assert_eq!(key_status, Some(ValidationResult::IncompleteGroup)); let metrics = 
export_metrics().unwrap(); - assert!(metrics.contains(&format!("validator_work_keys_soft_invalidated_total{{group_key=\"group:{}:2:0\",pool_id=\"0\",validator_id=\"0\"}} 1", GROUP_ID))); + assert!(metrics.contains(&format!("validator_work_keys_soft_invalidated_total{{group_key=\"group:{GROUP_ID}:2:0\",pool_id=\"0\",validator_id=\"0\"}} 1"))); Ok(()) } @@ -952,14 +952,14 @@ async fn test_incomplete_group_status_tracking() -> Result<(), Error> { mock_storage .add_file( - &format!("TestModel/dataset/test-{}-3-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-3-0-0.parquet"), "file1", ) .await; mock_storage .add_mapping_file( FILE_SHA_1, - &format!("TestModel/dataset/test-{}-3-0-0.parquet", GROUP_ID), + &format!("TestModel/dataset/test-{GROUP_ID}-3-0-0.parquet"), ) .await; @@ -1006,7 +1006,7 @@ async fn test_incomplete_group_status_tracking() -> Result<(), Error> { // Manually process groups past grace period to simulate what would happen // after the grace period expires (we simulate this since we can't wait in tests) - let group_key = format!("group:{}:3:0", GROUP_ID); + let group_key = format!("group:{GROUP_ID}:3:0"); // Manually add the group to tracking and then process it validator.track_incomplete_group(&group_key).await?; diff --git a/crates/validator/src/validators/synthetic_data/toploc.rs b/crates/validator/src/validators/synthetic_data/toploc.rs index 33d9f57f..f5641533 100644 --- a/crates/validator/src/validators/synthetic_data/toploc.rs +++ b/crates/validator/src/validators/synthetic_data/toploc.rs @@ -689,8 +689,7 @@ mod tests { Some(expected_idx) => { assert!( matched, - "Expected file {} to match config {}", - test_file, expected_idx + "Expected file {test_file} to match config {expected_idx}" ); assert_eq!( matched_idx, @@ -701,7 +700,7 @@ mod tests { expected_idx ); } - None => assert!(!matched, "File {} should not match any config", test_file), + None => assert!(!matched, "File {test_file} should not match any config"), } } } 
diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index 18596ba5..bd41e6d1 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -8,57 +8,49 @@ workspace = true [dependencies] shared = { workspace = true } +prime-core = { workspace = true} +p2p = { workspace = true } + actix-web = { workspace = true } -bollard = "0.18.1" +alloy = { workspace = true } +anyhow = { workspace = true } +cid = { workspace = true } clap = { workspace = true } -colored = "2.0" -lazy_static = "1.4" -regex = "1.10" +chrono = { workspace = true } +directories = { workspace = true } +env_logger = { workspace = true } +futures = { workspace = true } +futures-util = { workspace = true } +hex = { workspace = true } +log = { workspace = true } +rand_v8 = { workspace = true } +reqwest = { workspace = true, features = ["blocking"] } +rust-ipfs = { workspace = true } serde = { workspace = true } +serde_json = { workspace = true } +stun = { workspace = true } tokio = { workspace = true, features = ["full", "macros"] } +tokio-util = { workspace = true, features = ["rt"] } +url = { workspace = true } uuid = { workspace = true } -validator = { version = "0.16", features = ["derive"] } + +bollard = "0.18.1" +colored = "2.0" +lazy_static = "1.4" sysinfo = "0.30" libc = "0.2" nvml-wrapper = "0.10.0" -log = { workspace = true } -env_logger = { workspace = true } -futures-core = "0.3" -futures-util = { workspace = true } -alloy = { workspace = true } -url = { workspace = true } -serde_json = { workspace = true } -reqwest = { workspace = true, features = ["blocking"] } -hex = { workspace = true } console = "0.15.10" -indicatif = "0.17.9" -bytes = "1.9.0" -anyhow = { workspace = true } thiserror = "2.0.11" -toml = { workspace = true } -ctrlc = "3.4.5" -tokio-util = { workspace = true, features = ["rt"] } -futures = { workspace = true } -chrono = { workspace = true } serial_test = "0.5.1" -directories = { workspace = true } strip-ansi-escapes = "0.2.1" -nalgebra = { workspace = 
true } -stun = { workspace = true } sha2 = "0.10.8" unicode-width = "0.2.0" rand = "0.9.0" tempfile = "3.14.0" tracing-loki = "0.2.6" -tracing = "0.1.41" +tracing = { workspace = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } -tracing-log = "0.2.0" time = "0.3.41" -iroh = { workspace = true } -rand_v8 = { workspace = true } -rand_core_v6 = { workspace = true } -dashmap = "6.1.0" tokio-stream = { version = "0.1.17", features = ["net"] } -rust-ipfs = { workspace = true } -cid = { workspace = true } homedir = "0.3" diff --git a/crates/worker/src/checks/hardware/interconnect.rs b/crates/worker/src/checks/hardware/interconnect.rs index 21725686..d87d1819 100644 --- a/crates/worker/src/checks/hardware/interconnect.rs +++ b/crates/worker/src/checks/hardware/interconnect.rs @@ -78,7 +78,7 @@ mod tests { #[tokio::test] async fn test_check_speeds() { let result = InterconnectCheck::check_speeds().await; - println!("Test Result: {:?}", result); + println!("Test Result: {result:?}"); // Verify the result is Ok and contains expected tuple structure assert!(result.is_ok()); diff --git a/crates/worker/src/checks/hardware/storage.rs b/crates/worker/src/checks/hardware/storage.rs index 9509e731..8360993b 100644 --- a/crates/worker/src/checks/hardware/storage.rs +++ b/crates/worker/src/checks/hardware/storage.rs @@ -216,7 +216,7 @@ fn test_or_create_app_directory(path: &str) -> bool { } #[cfg(not(target_os = "linux"))] -pub fn find_largest_storage() -> Option { +pub(crate) fn find_largest_storage() -> Option { None } @@ -233,7 +233,7 @@ pub(crate) fn get_available_space(path: &str) -> Option { } #[cfg(not(target_os = "linux"))] -pub fn get_available_space(_path: &str) -> Option { +pub(crate) fn get_available_space(_path: &str) -> Option { None } diff --git a/crates/worker/src/checks/hardware/storage.rs:236:1 b/crates/worker/src/checks/hardware/storage.rs:236:1 new file mode 100644 index 00000000..e69de29b diff --git a/crates/worker/src/checks/stun.rs 
b/crates/worker/src/checks/stun.rs index 5830b49e..734f2795 100644 --- a/crates/worker/src/checks/stun.rs +++ b/crates/worker/src/checks/stun.rs @@ -139,7 +139,7 @@ mod tests { async fn test_get_public_ip() { let stun_check = StunCheck::new(Duration::from_secs(5), 0); let public_ip = stun_check.get_public_ip().await.unwrap(); - println!("Public IP: {}", public_ip); + println!("Public IP: {public_ip}"); assert!(!public_ip.is_empty()); } } diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs index 92de379e..566d63ad 100644 --- a/crates/worker/src/cli/command.rs +++ b/crates/worker/src/cli/command.rs @@ -6,26 +6,28 @@ use crate::console::Console; use crate::docker::taskbridge::TaskBridge; use crate::docker::DockerService; use crate::metrics::store::MetricsStore; -use crate::operations::compute_node::ComputeNodeOperations; use crate::operations::heartbeat::service::HeartbeatService; -use crate::operations::provider::ProviderOperations; -use crate::p2p::P2PContext; -use crate::p2p::P2PService; +use crate::operations::node_monitor::NodeMonitor; use crate::services::discovery::DiscoveryService; use crate::services::discovery_updater::DiscoveryUpdater; use crate::state::system_state::SystemState; use crate::TaskHandles; use alloy::primitives::utils::format_ether; +use alloy::primitives::Address; use alloy::primitives::U256; use alloy::signers::local::PrivateKeySigner; use alloy::signers::Signer; use clap::{Parser, Subcommand}; use log::{error, info}; +use prime_core::operations::compute_node::ComputeNodeOperations; +use prime_core::operations::provider::ProviderOperations; use shared::models::node::ComputeRequirements; use shared::models::node::Node; use shared::web3::contracts::core::builder::ContractBuilder; +use shared::web3::contracts::core::builder::Contracts; use shared::web3::contracts::structs::compute_pool::PoolStatus; use shared::web3::wallet::Wallet; +use shared::web3::wallet::WalletProvider; use std::str::FromStr; use std::sync::Arc; 
use std::time::Duration; @@ -56,13 +58,17 @@ pub enum Commands { #[arg(long, default_value = "8080")] port: u16, + /// Port for libp2p service + #[arg(long, default_value = "4002")] + libp2p_port: u16, + /// External IP address for the worker to advertise #[arg(long)] external_ip: Option, /// Compute pool ID #[arg(long)] - compute_pool_id: u64, + compute_pool_id: u32, /// Dry run the command without starting the worker #[arg(long, default_value = "false")] @@ -123,7 +129,7 @@ pub enum Commands { #[arg(long, default_value = "false")] with_ipfs_upload: bool, - #[arg(long, default_value = "4002")] + #[arg(long, default_value = "5001")] ipfs_port: u16, }, Check {}, @@ -176,7 +182,7 @@ pub enum Commands { /// Compute pool ID #[arg(long)] - compute_pool_id: u64, + compute_pool_id: u32, }, } @@ -188,6 +194,7 @@ pub async fn execute_command( match command { Commands::Run { port: _, + libp2p_port, external_ip, compute_pool_id, dry_run: _, @@ -214,11 +221,19 @@ pub async fn execute_command( ); std::process::exit(1); } - let state = Arc::new(SystemState::new( + let state = match SystemState::new( state_dir_overwrite.clone(), *disable_state_storing, - Some(compute_pool_id.to_string()), - )); + *compute_pool_id, + ) { + Ok(state) => state, + Err(e) => { + error!("❌ Failed to initialize system state: {e}"); + std::process::exit(1); + } + }; + + let state = Arc::new(state); let private_key_provider = if let Some(key) = private_key_provider { Console::warning("Using private key from command line is not recommended. 
Consider using PRIVATE_KEY_PROVIDER environment variable instead."); @@ -280,12 +295,10 @@ pub async fn execute_command( let provider_ops_cancellation = cancellation_token.clone(); - let compute_node_state = state.clone(); let compute_node_ops = ComputeNodeOperations::new( &provider_wallet_instance, &node_wallet_instance, contracts.clone(), - compute_node_state, ); let discovery_urls = vec![discovery_url @@ -296,7 +309,7 @@ pub async fn execute_command( let discovery_state = state.clone(); let discovery_updater = DiscoveryUpdater::new(discovery_service.clone(), discovery_state.clone()); - let pool_id = U256::from(*compute_pool_id as u32); + let pool_id = U256::from(*compute_pool_id); let pool_info = loop { match contracts.compute_pool.get_pool_info(pool_id).await { @@ -338,7 +351,7 @@ pub async fn execute_command( .address() .to_string(), compute_specs: None, - compute_pool_id: *compute_pool_id as u32, + compute_pool_id: *compute_pool_id, worker_p2p_id: None, worker_p2p_addresses: None, }; @@ -515,7 +528,6 @@ pub async fn execute_command( .default_signer() .address() .to_string(), - state.get_p2p_seed(), *disable_host_network_mode, )); @@ -593,7 +605,7 @@ pub async fn execute_command( .retry_register_provider( required_stake, *funding_retry_count, - cancellation_token.clone(), + Some(cancellation_token.clone()), ) .await { @@ -696,20 +708,11 @@ pub async fn execute_command( let heartbeat = match heartbeat_service.clone() { Ok(service) => service, Err(e) => { - error!("❌ Heartbeat service is not available: {e}"); + error!("❌ Heartbeat service is not available: {e:?}"); std::process::exit(1); } }; - let p2p_context = P2PContext { - docker_service: docker_service.clone(), - heartbeat_service: heartbeat.clone(), - system_state: state.clone(), - contracts: contracts.clone(), - node_wallet: node_wallet_instance.clone(), - provider_wallet: provider_wallet_instance.clone(), - }; - let validators = match contracts.prime_network.get_validator_role().await { Ok(validators) => 
validators, Err(e) => { @@ -728,15 +731,19 @@ pub async fn execute_command( let mut allowed_addresses = vec![pool_info.creator, pool_info.compute_manager_key]; allowed_addresses.extend(validators); - let p2p_service = match P2PService::new( - state.worker_p2p_seed, - cancellation_token.clone(), - Some(p2p_context), + let validator_addresses = std::collections::HashSet::from_iter(allowed_addresses); + let p2p_service = match crate::p2p::Service::new( + state.get_p2p_keypair().clone(), + *libp2p_port, node_wallet_instance.clone(), - allowed_addresses, - ) - .await - { + validator_addresses, + docker_service.clone(), + heartbeat.clone(), + state.clone(), + contracts.clone(), + provider_wallet_instance.clone(), + cancellation_token.clone(), + ) { Ok(service) => service, Err(e) => { error!("❌ Failed to start P2P service: {e}"); @@ -744,23 +751,21 @@ pub async fn execute_command( } }; - if let Err(e) = p2p_service.start() { - error!("❌ Failed to start P2P listener: {e}"); - std::process::exit(1); - } - - node_config.worker_p2p_id = Some(p2p_service.node_id().to_string()); + let peer_id = p2p_service.peer_id(); + node_config.worker_p2p_id = Some(peer_id.to_string()); + let external_p2p_address = + format!("/ip4/{}/tcp/{}", node_config.ip_address, *libp2p_port); node_config.worker_p2p_addresses = Some( p2p_service - .listening_addresses() + .listen_addrs() .iter() .map(|addr| addr.to_string()) + .chain(vec![external_p2p_address]) .collect(), ); - Console::success(&format!( - "P2P service started with ID: {}", - p2p_service.node_id() - )); + tokio::task::spawn(p2p_service.run()); + + Console::success(&format!("P2P service started with ID: {peer_id}",)); let mut attempts = 0; let max_attempts = 100; @@ -814,9 +819,15 @@ pub async fn execute_command( // Start monitoring compute node status on chain provider_ops.start_monitoring(provider_ops_cancellation); - let pool_id = state.compute_pool_id.clone().unwrap_or("0".to_string()); - if let Err(err) = 
compute_node_ops.start_monitoring(cancellation_token.clone(), pool_id) - { + let node_monitor = NodeMonitor::new( + provider_wallet_instance.clone(), + node_wallet_instance.clone(), + contracts.clone(), + state.clone(), + ); + + let pool_id = state.get_compute_pool_id(); + if let Err(err) = node_monitor.start_monitoring(cancellation_token.clone(), pool_id) { error!("❌ Failed to start node monitoring: {err}"); std::process::exit(1); } @@ -1021,11 +1032,10 @@ pub async fn execute_command( std::process::exit(1); } }; - let state = Arc::new(SystemState::new(None, true, None)); + /* Initialize dependencies - services, contracts, operations */ - let contracts = ContractBuilder::new(provider_wallet_instance.provider()) .with_compute_registry() .with_ai_token() @@ -1035,25 +1045,25 @@ pub async fn execute_command( .build() .unwrap(); - let compute_node_ops = ComputeNodeOperations::new( - &provider_wallet_instance, - &node_wallet_instance, - contracts.clone(), - state.clone(), - ); + let provider_address = provider_wallet_instance.wallet.default_signer().address(); + let node_address = node_wallet_instance.wallet.default_signer().address(); let provider_ops = ProviderOperations::new(provider_wallet_instance.clone(), contracts.clone(), false); - let compute_node_exists = match compute_node_ops.check_compute_node_exists().await { - Ok(exists) => exists, + let compute_node_exists = match contracts + .compute_registry + .get_node(provider_address, node_address) + .await + { + Ok(_) => true, Err(e) => { Console::user_error(&format!("❌ Failed to check if compute node exists: {e}")); std::process::exit(1); } }; - let pool_id = U256::from(*compute_pool_id as u32); + let pool_id = U256::from(*compute_pool_id); if compute_node_exists { match contracts @@ -1073,7 +1083,7 @@ pub async fn execute_command( std::process::exit(1); } } - match compute_node_ops.remove_compute_node().await { + match remove_compute_node(contracts, provider_address, node_address).await { Ok(_removed_node) => { 
Console::success("Compute node removed"); match provider_ops.reclaim_stake(U256::from(0)).await { @@ -1099,3 +1109,17 @@ pub async fn execute_command( } } } + +async fn remove_compute_node( + contracts: Contracts, + provider_address: Address, + node_address: Address, +) -> Result> { + Console::title("🔄 Removing compute node"); + let remove_node_tx = contracts + .prime_network + .remove_compute_node(provider_address, node_address) + .await?; + Console::success(&format!("Remove node tx: {remove_node_tx:?}")); + Ok(true) +} diff --git a/crates/worker/src/docker/service.rs b/crates/worker/src/docker/service.rs index 63425e2d..da15b88e 100644 --- a/crates/worker/src/docker/service.rs +++ b/crates/worker/src/docker/service.rs @@ -24,7 +24,6 @@ pub(crate) struct DockerService { system_memory_mb: Option, task_bridge_socket_path: String, node_address: String, - p2p_seed: Option, } const TASK_PREFIX: &str = "prime-task"; @@ -39,7 +38,6 @@ impl DockerService { task_bridge_socket_path: String, storage_path: String, node_address: String, - p2p_seed: Option, disable_host_network_mode: bool, ) -> Self { let docker_manager = @@ -52,7 +50,6 @@ impl DockerService { system_memory_mb, task_bridge_socket_path, node_address, - p2p_seed, } } @@ -177,7 +174,6 @@ impl DockerService { let system_memory_mb = self.system_memory_mb; let task_bridge_socket_path = self.task_bridge_socket_path.clone(); let node_address = self.node_address.clone(); - let p2p_seed = self.p2p_seed; let handle = tokio::spawn(async move { let Some(payload) = state_clone.get_current_task().await else { return; @@ -185,11 +181,7 @@ impl DockerService { let cmd = match payload.cmd { Some(cmd_vec) => { cmd_vec.into_iter().map(|arg| { - let mut processed_arg = arg.replace("${SOCKET_PATH}", &task_bridge_socket_path); - if let Some(seed) = p2p_seed { - processed_arg = processed_arg.replace("${WORKER_P2P_SEED}", &seed.to_string()); - } - processed_arg + arg.replace("${SOCKET_PATH}", &task_bridge_socket_path) }).collect() } 
None => vec!["sleep".to_string(), "infinity".to_string()], @@ -199,10 +191,7 @@ impl DockerService { if let Some(env) = &payload.env_vars { // Clone env vars and replace ${SOCKET_PATH} in values for (key, value) in env.iter() { - let mut processed_value = value.replace("${SOCKET_PATH}", &task_bridge_socket_path); - if let Some(seed) = p2p_seed { - processed_value = processed_value.replace("${WORKER_P2P_SEED}", &seed.to_string()); - } + let processed_value = value.replace("${SOCKET_PATH}", &task_bridge_socket_path); env_vars.insert(key.clone(), processed_value); } } @@ -432,7 +421,6 @@ mod tests { "/tmp/com.prime.miner/metrics.sock".to_string(), "/tmp/test-storage".to_string(), Address::ZERO.to_string(), - None, false, ); let task = Task { @@ -481,7 +469,6 @@ mod tests { test_socket_path.to_string(), "/tmp/test-storage".to_string(), Address::ZERO.to_string(), - Some(12345), // p2p_seed for testing false, ); diff --git a/crates/worker/src/docker/taskbridge/bridge.rs b/crates/worker/src/docker/taskbridge/bridge.rs index 65a28f76..594bc62d 100644 --- a/crates/worker/src/docker/taskbridge/bridge.rs +++ b/crates/worker/src/docker/taskbridge/bridge.rs @@ -473,7 +473,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0).unwrap()); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -506,7 +506,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0).unwrap()); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -541,7 +541,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path 
= temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0).unwrap()); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -565,7 +565,7 @@ mod tests { "test_label2": 20.0, }); let sample_metric = serde_json::to_string(&data)?; - debug!("Sending {:?}", sample_metric); + debug!("Sending {sample_metric:?}"); let msg = format!("{}{}", sample_metric, "\n"); stream.write_all(msg.as_bytes()).await?; stream.flush().await?; @@ -590,7 +590,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0).unwrap()); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -616,7 +616,7 @@ mod tests { "output/input_flops": 2500.0, }); let sample_metric = serde_json::to_string(&json)?; - debug!("Sending {:?}", sample_metric); + debug!("Sending {sample_metric:?}"); let msg = format!("{}{}", sample_metric, "\n"); stream.write_all(msg.as_bytes()).await?; stream.flush().await?; @@ -626,8 +626,7 @@ mod tests { let all_metrics = metrics_store.get_all_metrics().await; assert!( all_metrics.is_empty(), - "Expected metrics to be empty but found: {:?}", - all_metrics + "Expected metrics to be empty but found: {all_metrics:?}" ); bridge_handle.abort(); @@ -639,7 +638,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0).unwrap()); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), diff --git 
a/crates/worker/src/operations/heartbeat/service.rs b/crates/worker/src/operations/heartbeat/service.rs index 0d77d783..289e86af 100644 --- a/crates/worker/src/operations/heartbeat/service.rs +++ b/crates/worker/src/operations/heartbeat/service.rs @@ -24,7 +24,6 @@ pub(crate) struct HeartbeatService { docker_service: Arc, metrics_store: Arc, } - #[derive(Debug, Clone, thiserror::Error)] pub(crate) enum HeartbeatError { #[error("HTTP request failed")] @@ -32,6 +31,7 @@ pub(crate) enum HeartbeatError { #[error("Service initialization failed")] InitFailed, } + impl HeartbeatService { #[allow(clippy::too_many_arguments)] pub(crate) fn new( @@ -143,7 +143,7 @@ async fn send_heartbeat( wallet: Wallet, docker_service: Arc, metrics_store: Arc, - p2p_id: Option, + p2p_id: p2p::PeerId, ) -> Result { if endpoint.is_none() { return Err(HeartbeatError::RequestFailed); @@ -176,7 +176,7 @@ async fn send_heartbeat( .to_string(), ), timestamp: Some(ts), - p2p_id, + p2p_id: Some(p2p_id.to_string()), // TODO: this should always be `Some` task_details, } } else { @@ -188,7 +188,7 @@ async fn send_heartbeat( .to_string(), ), timestamp: Some(ts), - p2p_id, + p2p_id: Some(p2p_id.to_string()), // TODO: this should always be `Some` ..Default::default() } }; diff --git a/crates/worker/src/operations/mod.rs b/crates/worker/src/operations/mod.rs index 193b64ae..d684160a 100644 --- a/crates/worker/src/operations/mod.rs +++ b/crates/worker/src/operations/mod.rs @@ -1,3 +1,2 @@ -pub(crate) mod compute_node; pub(crate) mod heartbeat; -pub(crate) mod provider; +pub(crate) mod node_monitor; diff --git a/crates/worker/src/operations/compute_node.rs b/crates/worker/src/operations/node_monitor.rs similarity index 51% rename from crates/worker/src/operations/compute_node.rs rename to crates/worker/src/operations/node_monitor.rs index 39b18c29..af33d450 100644 --- a/crates/worker/src/operations/compute_node.rs +++ b/crates/worker/src/operations/node_monitor.rs @@ -1,5 +1,5 @@ -use 
crate::{console::Console, state::system_state::SystemState}; -use alloy::{primitives::utils::keccak256 as keccak, primitives::U256, signers::Signer}; +use crate::state::system_state::SystemState; +use alloy::primitives::U256; use anyhow::Result; use shared::web3::wallet::Wallet; use shared::web3::{contracts::core::builder::Contracts, wallet::WalletProvider}; @@ -7,17 +7,17 @@ use std::sync::Arc; use tokio::time::{sleep, Duration}; use tokio_util::sync::CancellationToken; -pub(crate) struct ComputeNodeOperations<'c> { - provider_wallet: &'c Wallet, - node_wallet: &'c Wallet, +pub(crate) struct NodeMonitor { + provider_wallet: Wallet, + node_wallet: Wallet, contracts: Contracts, system_state: Arc, } -impl<'c> ComputeNodeOperations<'c> { +impl NodeMonitor { pub(crate) fn new( - provider_wallet: &'c Wallet, - node_wallet: &'c Wallet, + provider_wallet: Wallet, + node_wallet: Wallet, contracts: Contracts, system_state: Arc, ) -> Self { @@ -32,7 +32,7 @@ impl<'c> ComputeNodeOperations<'c> { pub(crate) fn start_monitoring( &self, cancellation_token: CancellationToken, - pool_id: String, + pool_id: u32, ) -> Result<()> { let provider_address = self.provider_wallet.wallet.default_signer().address(); let node_address = self.node_wallet.wallet.default_signer().address(); @@ -43,11 +43,12 @@ impl<'c> ComputeNodeOperations<'c> { let mut last_claimable = None; let mut last_locked = None; let mut first_check = true; + tokio::spawn(async move { loop { tokio::select! 
{ _ = cancellation_token.cancelled() => { - Console::info("Monitor", "Shutting down node status monitor..."); + log::info!("Shutting down node status monitor..."); break; } _ = async { @@ -55,16 +56,15 @@ impl<'c> ComputeNodeOperations<'c> { Ok((active, validated)) => { if first_check || active != last_active { if !first_check { - Console::info("🔄 Chain Sync - Pool membership changed", &format!("From {last_active} to {active}" - )); + log::info!("🔄 Chain Sync - Pool membership changed: From {last_active} to {active}"); } else { - Console::info("🔄 Chain Sync - Node pool membership", &format!("{active}")); + log::info!("🔄 Chain Sync - Node pool membership: {active}"); } last_active = active; } let is_running = system_state.is_running().await; if !active && is_running { - Console::warning("Node is not longer in pool, shutting down heartbeat..."); + log::warn!("Node is not longer in pool, shutting down heartbeat..."); if let Err(e) = system_state.set_running(false, None).await { log::error!("Failed to set running to false: {e:?}"); } @@ -72,18 +72,16 @@ impl<'c> ComputeNodeOperations<'c> { if first_check || validated != last_validated { if !first_check { - Console::info("🔄 Chain Sync - Validation changed", &format!("From {last_validated} to {validated}" - )); + log::info!("🔄 Chain Sync - Validation changed: From {last_validated} to {validated}"); } else { - Console::info("🔄 Chain Sync - Node validation", &format!("{validated}")); + log::info!("🔄 Chain Sync - Node validation: {validated}"); } last_validated = validated; } // Check rewards for the current compute pool - if let Ok(pool_id_u32) = pool_id.parse::() { match contracts.compute_pool.calculate_node_rewards( - U256::from(pool_id_u32), + U256::from(pool_id), node_address, ).await { Ok((claimable, locked)) => { @@ -92,13 +90,13 @@ impl<'c> ComputeNodeOperations<'c> { last_locked = Some(locked); let claimable_formatted = claimable.to_string().parse::().unwrap_or(0.0) / 10f64.powf(18.0); let locked_formatted = 
locked.to_string().parse::().unwrap_or(0.0) / 10f64.powf(18.0); - Console::info("Rewards", &format!("{claimable_formatted} claimable, {locked_formatted} locked")); + log::info!("Rewards: {claimable_formatted} claimable, {locked_formatted} locked"); } } Err(e) => { - log::debug!("Failed to check rewards for pool {pool_id_u32}: {e}"); + log::debug!("Failed to check rewards for pool {pool_id}: {e}"); } - } + } first_check = false; @@ -114,74 +112,4 @@ impl<'c> ComputeNodeOperations<'c> { }); Ok(()) } - - pub(crate) async fn check_compute_node_exists( - &self, - ) -> Result> { - let compute_node = self - .contracts - .compute_registry - .get_node( - self.provider_wallet.wallet.default_signer().address(), - self.node_wallet.wallet.default_signer().address(), - ) - .await; - - match compute_node { - Ok(_) => Ok(true), - Err(_) => Ok(false), - } - } - - // Returns true if the compute node was added, false if it already exists - pub(crate) async fn add_compute_node( - &self, - compute_units: U256, - ) -> Result> { - Console::title("🔄 Adding compute node"); - - if self.check_compute_node_exists().await? { - return Ok(false); - } - - Console::progress("Adding compute node"); - let provider_address = self.provider_wallet.wallet.default_signer().address(); - let node_address = self.node_wallet.wallet.default_signer().address(); - let digest = keccak([provider_address.as_slice(), node_address.as_slice()].concat()); - - let signature = self - .node_wallet - .signer - .sign_message(digest.as_slice()) - .await? - .as_bytes(); - - // Create the signature bytes - let add_node_tx = self - .contracts - .prime_network - .add_compute_node(node_address, compute_units, signature.to_vec()) - .await?; - Console::success(&format!("Add node tx: {add_node_tx:?}")); - Ok(true) - } - - pub(crate) async fn remove_compute_node(&self) -> Result> { - Console::title("🔄 Removing compute node"); - - if !self.check_compute_node_exists().await? 
{ - return Ok(false); - } - - Console::progress("Removing compute node"); - let provider_address = self.provider_wallet.wallet.default_signer().address(); - let node_address = self.node_wallet.wallet.default_signer().address(); - let remove_node_tx = self - .contracts - .prime_network - .remove_compute_node(provider_address, node_address) - .await?; - Console::success(&format!("Remove node tx: {remove_node_tx:?}")); - Ok(true) - } } diff --git a/crates/worker/src/p2p/mod.rs b/crates/worker/src/p2p/mod.rs index 9393f985..94fe10a3 100644 --- a/crates/worker/src/p2p/mod.rs +++ b/crates/worker/src/p2p/mod.rs @@ -1,4 +1,497 @@ -pub(crate) mod service; +use anyhow::Context as _; +use anyhow::Result; +use futures::stream::FuturesUnordered; +use p2p::InviteRequestUrl; +use p2p::Node; +use p2p::NodeBuilder; +use p2p::PeerId; +use p2p::Response; +use p2p::{IncomingMessage, Libp2pIncomingMessage, OutgoingMessage}; +use shared::web3::contracts::core::builder::Contracts; +use shared::web3::wallet::Wallet; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use std::time::SystemTime; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; -pub(crate) use service::P2PContext; -pub(crate) use service::P2PService; +use crate::docker::DockerService; +use crate::operations::heartbeat::service::HeartbeatService; +use crate::state::system_state::SystemState; +use shared::web3::wallet::WalletProvider; + +pub(crate) struct Service { + node: Node, + incoming_messages: Receiver, + cancellation_token: CancellationToken, + context: Context, +} + +impl Service { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + keypair: p2p::Keypair, + port: u16, + wallet: Wallet, + validator_addresses: HashSet, + docker_service: Arc, + heartbeat_service: Arc, + system_state: Arc, + contracts: Contracts, + provider_wallet: Wallet, + cancellation_token: CancellationToken, + ) -> Result { + let (node, 
incoming_messages, outgoing_messages) = + build_p2p_node(keypair, port, cancellation_token.clone()) + .context("failed to build p2p node")?; + Ok(Self { + node, + incoming_messages, + cancellation_token, + context: Context::new( + wallet, + outgoing_messages, + validator_addresses, + docker_service, + heartbeat_service, + system_state, + contracts, + provider_wallet, + ), + }) + } + + pub(crate) fn peer_id(&self) -> PeerId { + self.node.peer_id() + } + + pub(crate) fn listen_addrs(&self) -> &[p2p::Multiaddr] { + self.node.listen_addrs() + } + + pub(crate) async fn run(self) { + use futures::StreamExt as _; + + let Self { + node, + mut incoming_messages, + cancellation_token, + context, + } = self; + + tokio::task::spawn(node.run()); + + let mut message_handlers = FuturesUnordered::new(); + + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break; + } + Some(message) = incoming_messages.recv() => { + let context = context.clone(); + let handle = tokio::task::spawn( + handle_incoming_message(message, context) + ); + message_handlers.push(handle); + } + Some(res) = message_handlers.next() => { + if let Err(e) = res { + tracing::error!("failed to handle incoming message: {e}"); + } + } + } + } + } +} + +fn build_p2p_node( + keypair: p2p::Keypair, + port: u16, + cancellation_token: CancellationToken, +) -> Result<(Node, Receiver, Sender)> { + let (node, incoming_message_rx, outgoing_message_tx) = NodeBuilder::new() + .with_keypair(keypair) + .with_port(port) + .with_authentication() + .with_hardware_challenge() + .with_invite() + .with_get_task_logs() + .with_restart() + .with_cancellation_token(cancellation_token) + .try_build() + .context("failed to build p2p node")?; + Ok((node, incoming_message_rx, outgoing_message_tx)) +} + +#[derive(Clone)] +struct Context { + authorized_peers: Arc>>, + wallet: Wallet, + validator_addresses: Arc>, + + // for validator authentication requests + ongoing_auth_challenges: Arc>>, // use request_id? 
+ nonce_cache: Arc>>, + outgoing_messages: Sender, + + // for get_task_logs and restart requests + docker_service: Arc, + + // for invite requests + heartbeat_service: Arc, + system_state: Arc, + contracts: Contracts, + provider_wallet: Wallet, +} + +impl Context { + #[allow(clippy::too_many_arguments)] + fn new( + wallet: Wallet, + outgoing_messages: Sender, + validator_addresses: HashSet, + docker_service: Arc, + heartbeat_service: Arc, + system_state: Arc, + contracts: Contracts, + provider_wallet: Wallet, + ) -> Self { + Self { + authorized_peers: Arc::new(RwLock::new(HashSet::new())), + ongoing_auth_challenges: Arc::new(RwLock::new(HashMap::new())), + nonce_cache: Arc::new(RwLock::new(HashMap::new())), + wallet, + outgoing_messages, + validator_addresses: Arc::new(validator_addresses), + docker_service, + heartbeat_service, + system_state, + contracts, + provider_wallet, + } + } +} + +async fn handle_incoming_message(message: IncomingMessage, context: Context) -> Result<()> { + match message.message { + Libp2pIncomingMessage::Request { + request_id: _, + request, + channel, + } => { + tracing::debug!("received incoming request {request:?}"); + handle_incoming_request(message.peer, request, channel, context).await?; + } + Libp2pIncomingMessage::Response { + request_id: _, + response, + } => { + tracing::debug!("received incoming response {response:?}"); + handle_incoming_response(response); + } + } + Ok(()) +} + +async fn handle_incoming_request( + from: PeerId, + request: p2p::Request, + channel: p2p::ResponseChannel, + context: Context, +) -> Result<()> { + let resp = match request { + p2p::Request::Authentication(req) => { + tracing::debug!("handling ValidatorAuthentication request"); + match req { + p2p::AuthenticationRequest::Initiation(req) => { + handle_validator_authentication_initiation_request(from, req, &context) + .await + .context("failed to handle ValidatorAuthenticationInitiationRequest")? 
+ } + p2p::AuthenticationRequest::Solution(req) => { + match handle_validator_authentication_solution_request(from, req, &context) + .await + { + Ok(()) => p2p::AuthenticationSolutionResponse::Granted.into(), + Err(e) => { + tracing::error!( + "failed to handle ValidatorAuthenticationSolutionRequest: {e:?}" + ); + p2p::AuthenticationSolutionResponse::Rejected.into() + } + } + } + } + } + p2p::Request::HardwareChallenge(req) => { + tracing::debug!("handling HardwareChallenge request"); + handle_hardware_challenge_request(from, req, &context) + .await + .context("failed to handle HardwareChallenge request")? + } + p2p::Request::Invite(req) => { + tracing::debug!("handling Invite request"); + match handle_invite_request(from, req, &context).await { + Ok(()) => p2p::InviteResponse::Ok.into(), + Err(e) => p2p::InviteResponse::Error(e.to_string()).into(), + } + } + p2p::Request::GetTaskLogs => { + tracing::debug!("handling GetTaskLogs request"); + handle_get_task_logs_request(from, &context).await + } + p2p::Request::RestartTask => { + tracing::debug!("handling Restart request"); + handle_restart_request(from, &context).await + } + p2p::Request::General(_) => { + todo!() + } + }; + + let outgoing_message = resp.into_outgoing_message(channel); + context + .outgoing_messages + .send(outgoing_message) + .await + .context("failed to send ValidatorAuthentication response")?; + + Ok(()) +} + +async fn handle_validator_authentication_initiation_request( + from: PeerId, + req: p2p::AuthenticationInitiationRequest, + context: &Context, +) -> Result { + use rand_v8::Rng as _; + use shared::security::request_signer::sign_message; + + // generate a fresh cryptographically secure challenge message for this auth attempt + let challenge_bytes: [u8; 32] = rand_v8::rngs::OsRng.gen(); + let challenge_message = hex::encode(challenge_bytes); + let signature = sign_message(&req.message, &context.wallet) + .await + .map_err(|e| anyhow::anyhow!("failed to sign message: {e:?}"))?; + + // store 
the challenge message in nonce cache to prevent replay + let mut nonce_cache = context.nonce_cache.write().await; + nonce_cache.insert(challenge_message.clone(), SystemTime::now()); + + // store the current challenge for this peer + let mut ongoing_auth_challenges = context.ongoing_auth_challenges.write().await; + ongoing_auth_challenges.insert(from, challenge_message.clone()); + + Ok(p2p::AuthenticationInitiationResponse { + message: challenge_message, + signature, + } + .into()) +} + +async fn handle_validator_authentication_solution_request( + from: PeerId, + req: p2p::AuthenticationSolutionRequest, + context: &Context, +) -> Result<()> { + use std::str::FromStr as _; + + let mut ongoing_auth_challenges = context.ongoing_auth_challenges.write().await; + let challenge_message = ongoing_auth_challenges + .remove(&from) + .ok_or_else(|| anyhow::anyhow!("no ongoing authentication challenge for peer {from}"))?; + + let mut nonce_cache = context.nonce_cache.write().await; + if nonce_cache.remove(&challenge_message).is_none() { + anyhow::bail!("challenge message {challenge_message} not found in nonce cache"); + } + + let Ok(signature) = alloy::primitives::Signature::from_str(&req.signature) else { + anyhow::bail!("failed to parse signature from message"); + }; + + let Ok(recovered_address) = signature.recover_address_from_msg(challenge_message) else { + anyhow::bail!("failed to recover address from signature and message"); + }; + + if !context.validator_addresses.contains(&recovered_address) { + anyhow::bail!("recovered address {recovered_address} is not in the list of authorized validator addresses"); + } + + let mut authorized_peers = context.authorized_peers.write().await; + authorized_peers.insert(from); + Ok(()) +} + +async fn handle_hardware_challenge_request( + from: PeerId, + request: p2p::HardwareChallengeRequest, + context: &Context, +) -> Result { + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + // 
TODO: error response variant? + anyhow::bail!("unauthorized peer {from} attempted to access HardwareChallenge request"); + } + + let challenge_response = p2p::calc_matrix(&request.challenge); + let response = p2p::HardwareChallengeResponse { + response: challenge_response, + timestamp: SystemTime::now(), + }; + Ok(response.into()) +} + +async fn handle_get_task_logs_request(from: PeerId, context: &Context) -> Response { + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + return p2p::GetTaskLogsResponse::Error("unauthorized".to_string()).into(); + } + + match context.docker_service.get_logs().await { + Ok(logs) => p2p::GetTaskLogsResponse::Ok(logs).into(), + Err(e) => p2p::GetTaskLogsResponse::Error(format!("failed to get task logs: {e:?}")).into(), + } +} + +async fn handle_restart_request(from: PeerId, context: &Context) -> Response { + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + return p2p::RestartTaskResponse::Error("unauthorized".to_string()).into(); + } + + match context.docker_service.restart_task().await { + Ok(()) => p2p::RestartTaskResponse::Ok.into(), + Err(e) => p2p::RestartTaskResponse::Error(format!("failed to restart task: {e:?}")).into(), + } +} + +fn handle_incoming_response(response: p2p::Response) { + // critical developer error if any of these happen, could panic here + match response { + p2p::Response::Authentication(_) => { + tracing::error!("worker should never receive ValidatorAuthentication responses"); + } + p2p::Response::HardwareChallenge(_) => { + tracing::error!("worker should never receive HardwareChallenge responses"); + } + p2p::Response::Invite(_) => { + tracing::error!("worker should never receive Invite responses"); + } + p2p::Response::GetTaskLogs(_) => { + tracing::error!("worker should never receive GetTaskLogs responses"); + } + p2p::Response::RestartTask(_) => { + tracing::error!("worker should never receive 
Restart responses"); + } + p2p::Response::General(_) => { + todo!() + } + } +} + +async fn handle_invite_request( + from: PeerId, + req: p2p::InviteRequest, + context: &Context, +) -> Result<()> { + use crate::console::Console; + use shared::web3::contracts::helpers::utils::retry_call; + use shared::web3::contracts::structs::compute_pool::PoolStatus; + + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + return Err(anyhow::anyhow!( + "unauthorized peer {from} attempted to send invite" + )); + } + + if context.system_state.is_running().await { + anyhow::bail!("heartbeat is currently running and in a compute pool"); + } + + if req.pool_id != context.system_state.get_compute_pool_id() { + anyhow::bail!( + "pool ID mismatch: expected {}, got {}", + context.system_state.get_compute_pool_id(), + req.pool_id + ); + } + + let invite_bytes = hex::decode(&req.invite).context("failed to decode invite hex")?; + + if invite_bytes.len() < 65 { + anyhow::bail!("invite data is too short, expected at least 65 bytes"); + } + + let contracts = &context.contracts; + let pool_id = alloy::primitives::U256::from(req.pool_id); + + let bytes_array: [u8; 65] = match invite_bytes[..65].try_into() { + Ok(array) => array, + Err(_) => { + anyhow::bail!("failed to convert invite bytes to 65 byte array"); + } + }; + + let provider_address = context.provider_wallet.wallet.default_signer().address(); + + let pool_info = match contracts.compute_pool.get_pool_info(pool_id).await { + Ok(info) => info, + Err(err) => { + anyhow::bail!("failed to get pool info: {err:?}"); + } + }; + + if let PoolStatus::PENDING = pool_info.status { + anyhow::bail!("invalid invite; pool is pending"); + } + + let node_address = vec![context.wallet.wallet.default_signer().address()]; + let signatures = vec![alloy::primitives::FixedBytes::from(&bytes_array)]; + let call = contracts + .compute_pool + .build_join_compute_pool_call( + pool_id, + provider_address, + 
node_address, + vec![req.nonce], + vec![req.expiration], + signatures, + ) + .map_err(|e| anyhow::anyhow!("failed to build join compute pool call: {e:?}"))?; + + let provider = &context.provider_wallet.provider; + match retry_call(call, 3, provider.clone(), None).await { + Ok(result) => { + Console::section("WORKER JOINED COMPUTE POOL"); + Console::success(&format!( + "Successfully registered on chain with tx: {result}" + )); + Console::info( + "Status", + "Worker is now part of the compute pool and ready to receive tasks", + ); + } + Err(err) => { + anyhow::bail!("failed to join compute pool: {err:?}"); + } + } + + let heartbeat_endpoint = match req.url { + InviteRequestUrl::MasterIpPort(ip, port) => { + format!("http://{ip}:{port}/heartbeat") + } + InviteRequestUrl::MasterUrl(url) => format!("{url}/heartbeat"), + }; + + context + .heartbeat_service + .start(heartbeat_endpoint) + .await + .context("failed to start heartbeat service")?; + Ok(()) +} diff --git a/crates/worker/src/p2p/service.rs b/crates/worker/src/p2p/service.rs deleted file mode 100644 index 51a68405..00000000 --- a/crates/worker/src/p2p/service.rs +++ /dev/null @@ -1,736 +0,0 @@ -use crate::console::Console; -use crate::docker::DockerService; -use crate::operations::heartbeat::service::HeartbeatService; -use crate::state::system_state::SystemState; -use alloy::primitives::{Address, FixedBytes, U256}; -use anyhow::Result; -use dashmap::DashMap; -use iroh::endpoint::Incoming; -use iroh::{Endpoint, RelayMode, SecretKey}; -use lazy_static::lazy_static; -use log::{debug, error, info, warn}; -use rand_v8::Rng; -use shared::models::challenge::calc_matrix; -use shared::models::invite::InviteRequest; -use shared::p2p::messages::MAX_MESSAGE_SIZE; -use shared::p2p::messages::{P2PMessage, P2PRequest, P2PResponse}; -use shared::p2p::protocol::PRIME_P2P_PROTOCOL; -use shared::security::request_signer::sign_message; -use shared::web3::contracts::core::builder::Contracts; -use 
shared::web3::contracts::helpers::utils::retry_call; -use shared::web3::contracts::structs::compute_pool::PoolStatus; -use shared::web3::wallet::{Wallet, WalletProvider}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; -use tokio_util::sync::CancellationToken; - -lazy_static! { - static ref NONCE_CACHE: DashMap = DashMap::new(); -} - -#[derive(Clone)] -pub(crate) struct P2PContext { - pub docker_service: Arc, - pub heartbeat_service: Arc, - pub system_state: Arc, - pub contracts: Contracts, - pub node_wallet: Wallet, - pub provider_wallet: Wallet, -} - -#[derive(Clone)] -pub(crate) struct P2PService { - endpoint: Endpoint, - secret_key: SecretKey, - node_id: String, - listening_addrs: Vec, - cancellation_token: CancellationToken, - context: Option, - allowed_addresses: Vec
, - wallet: Wallet, -} - -enum EndpointLoopResult { - Shutdown, - EndpointClosed, -} - -impl P2PService { - /// Create a new P2P service with a unique worker identity - pub(crate) async fn new( - worker_p2p_seed: Option, - cancellation_token: CancellationToken, - context: Option, - wallet: Wallet, - allowed_addresses: Vec
, - ) -> Result { - // Generate or derive the secret key for this worker - let secret_key = if let Some(seed) = worker_p2p_seed { - // Derive from seed for deterministic identity - let mut seed_bytes = [0u8; 32]; - seed_bytes[..8].copy_from_slice(&seed.to_le_bytes()); - SecretKey::from_bytes(&seed_bytes) - } else { - let mut rng = rand_v8::thread_rng(); - SecretKey::generate(&mut rng) - }; - - let node_id = secret_key.public().to_string(); - info!("Starting P2P service with node ID: {node_id}"); - - // Create the endpoint - let endpoint = Endpoint::builder() - .secret_key(secret_key.clone()) - .alpns(vec![PRIME_P2P_PROTOCOL.to_vec()]) - .discovery_n0() - .relay_mode(RelayMode::Default) - .bind() - .await?; - - // Get listening addresses - let node_addr = endpoint.node_addr().await?; - let listening_addrs = node_addr - .direct_addresses - .iter() - .map(|addr| addr.to_string()) - .collect::>(); - - info!("P2P service listening on: {listening_addrs:?}"); - - Ok(Self { - endpoint, - secret_key, - node_id, - listening_addrs, - cancellation_token, - context, - allowed_addresses, - wallet, - }) - } - - /// Get the P2P node ID - pub(crate) fn node_id(&self) -> &str { - &self.node_id - } - - /// Get the listening addresses - pub(crate) fn listening_addresses(&self) -> &[String] { - &self.listening_addrs - } - - /// Recreate the endpoint with the same identity - async fn recreate_endpoint(&self) -> Result { - info!("Recreating P2P endpoint with node ID: {}", self.node_id); - - let endpoint = Endpoint::builder() - .secret_key(self.secret_key.clone()) - .alpns(vec![PRIME_P2P_PROTOCOL.to_vec()]) - .discovery_n0() - .relay_mode(RelayMode::Default) - .bind() - .await?; - - let node_addr = endpoint.node_addr().await?; - let listening_addrs = node_addr - .direct_addresses - .iter() - .map(|addr| addr.to_string()) - .collect::>(); - - info!("P2P endpoint recreated, listening on: {listening_addrs:?}"); - Ok(endpoint) - } - /// Start accepting incoming connections with automatic 
recovery - pub(crate) fn start(&self) -> Result<()> { - let service = Arc::new(self.clone()); - let cancellation_token = self.cancellation_token.clone(); - - tokio::spawn(async move { - service.run_with_recovery(cancellation_token).await; - }); - - Ok(()) - } - - /// Run the P2P service with automatic endpoint recovery - async fn run_with_recovery(&self, cancellation_token: CancellationToken) { - let mut endpoint = self.endpoint.clone(); - let mut retry_delay = Duration::from_secs(1); - const MAX_RETRY_DELAY: Duration = Duration::from_secs(60); - - loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - info!("P2P service shutting down"); - break; - } - result = self.run_endpoint_loop(&endpoint, &cancellation_token) => { - match result { - EndpointLoopResult::Shutdown => break, - EndpointLoopResult::EndpointClosed => { - warn!("P2P endpoint closed, attempting recovery in {retry_delay:?}"); - - tokio::select! { - _ = cancellation_token.cancelled() => break, - _ = tokio::time::sleep(retry_delay) => {} - } - - match self.recreate_endpoint().await { - Ok(new_endpoint) => { - info!("P2P endpoint successfully recovered"); - endpoint = new_endpoint; - retry_delay = Duration::from_secs(1); - } - Err(e) => { - error!("Failed to recreate P2P endpoint: {e}"); - retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); - } - } - } - } - } - } - } - } - - /// Run the main endpoint acceptance loop - async fn run_endpoint_loop( - &self, - endpoint: &Endpoint, - cancellation_token: &CancellationToken, - ) -> EndpointLoopResult { - let context = self.context.clone(); - let allowed_addresses = self.allowed_addresses.clone(); - let wallet = self.wallet.clone(); - - loop { - tokio::select! 
{ - _ = cancellation_token.cancelled() => { - return EndpointLoopResult::Shutdown; - } - incoming = endpoint.accept() => { - if let Some(incoming) = incoming { - tokio::spawn(Self::handle_connection(incoming, context.clone(), allowed_addresses.clone(), wallet.clone())); - } else { - return EndpointLoopResult::EndpointClosed; - } - } - } - } - } - - /// Handle an incoming connection - async fn handle_connection( - incoming: Incoming, - context: Option, - allowed_addresses: Vec
, - wallet: Wallet, - ) { - match incoming.await { - Ok(connection) => { - match connection.accept_bi().await { - Ok((send, recv)) => { - if let Err(e) = - Self::handle_stream(send, recv, context, allowed_addresses, wallet) - .await - { - error!("Error handling stream: {e}"); - } - // Wait a bit before closing to ensure client has processed response - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - Err(e) => { - error!("Failed to accept bi-stream: {e}"); - connection.close(1u32.into(), b"stream error"); - } - } - } - Err(e) => { - // Only log as debug for protocol mismatches, which are expected - if e.to_string() - .contains("peer doesn't support any known protocol") - { - debug!("Connection attempt with unsupported protocol: {e}"); - } else { - error!("Failed to accept connection: {e}"); - } - } - } - } - - /// Read a message from the stream - async fn read_message(recv: &mut iroh::endpoint::RecvStream) -> Result { - // Read message length - let mut msg_len_bytes = [0u8; 4]; - match recv.read_exact(&mut msg_len_bytes).await { - Ok(_) => {} - Err(e) => { - debug!("Stream read ended: {e}"); - return Err(anyhow::anyhow!("Stream closed")); - } - } - let msg_len = u32::from_be_bytes(msg_len_bytes) as usize; - - // Enforce maximum message size - if msg_len > MAX_MESSAGE_SIZE { - error!("Message size {msg_len} exceeds maximum allowed size {MAX_MESSAGE_SIZE}"); - return Err(anyhow::anyhow!("Message too large")); - } - - let mut msg_bytes = vec![0u8; msg_len]; - recv.read_exact(&mut msg_bytes).await?; - - let request: P2PRequest = serde_json::from_slice(&msg_bytes) - .map_err(|e| anyhow::anyhow!("Failed to deserialize P2P request: {}", e))?; - - debug!("Received P2P request: {request:?}"); - Ok(request) - } - - async fn write_response( - send: &mut iroh::endpoint::SendStream, - response: P2PResponse, - ) -> Result<()> { - let response_bytes = serde_json::to_vec(&response)?; - - // Check response size before sending - if response_bytes.len() > 
MAX_MESSAGE_SIZE { - error!( - "Response size {} exceeds maximum allowed size {}", - response_bytes.len(), - MAX_MESSAGE_SIZE - ); - return Err(anyhow::anyhow!("Response too large")); - } - - send.write_all(&(response_bytes.len() as u32).to_be_bytes()) - .await?; - send.write_all(&response_bytes).await?; - Ok(()) - } - - /// Handle a bidirectional stream - async fn handle_stream( - mut send: iroh::endpoint::SendStream, - mut recv: iroh::endpoint::RecvStream, - context: Option, - allowed_addresses: Vec
, - wallet: Wallet, - ) -> Result<()> { - // Handle multiple messages in sequence - let mut is_authorized = false; - let mut current_challenge: Option = None; - - loop { - let Ok(request) = Self::read_message(&mut recv).await else { - break; - }; - - // Handle the request - let response = match request.message { - P2PMessage::Ping { nonce, .. } => { - info!("Received ping with nonce: {nonce}"); - P2PResponse::new( - request.id, - P2PMessage::Pong { - timestamp: SystemTime::now(), - nonce, - }, - ) - } - P2PMessage::RequestAuthChallenge { message } => { - // Generate a fresh cryptographically secure challenge message for this auth attempt - let challenge_bytes: [u8; 32] = rand_v8::rngs::OsRng.gen(); - let challenge_message = hex::encode(challenge_bytes); - - debug!("Received request auth challenge"); - let signature = match sign_message(&message, &wallet).await { - Ok(signature) => signature, - Err(e) => { - error!("Failed to sign message: {e}"); - return Err(anyhow::anyhow!("Failed to sign message: {}", e)); - } - }; - - // Store the challenge message in nonce cache to prevent replay - NONCE_CACHE.insert(challenge_message.clone(), SystemTime::now()); - - // Store the current challenge for this connection - current_challenge = Some(challenge_message.clone()); - - P2PResponse::new( - request.id, - P2PMessage::AuthChallenge { - message: challenge_message, - signed_message: signature, - }, - ) - } - P2PMessage::AuthSolution { signed_message } => { - // Get the challenge message for this connection - debug!("Received auth solution"); - let Some(challenge_message) = ¤t_challenge else { - warn!("No active challenge for auth solution"); - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - }; - - // Check if challenge message has been used before (replay attack prevention) - if !NONCE_CACHE.contains_key(challenge_message) { - warn!("Challenge message not found or expired: 
{challenge_message}"); - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - } - - // Clean up old nonces (older than 5 minutes) - let cutoff_time = SystemTime::now() - Duration::from_secs(300); - NONCE_CACHE.retain(|_, &mut timestamp| timestamp > cutoff_time); - - // Parse the signature - let Ok(parsed_signature) = - alloy::primitives::Signature::from_str(&signed_message) - else { - // Handle signature parsing error - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - }; - - // Recover address from the challenge message that the client signed - let Ok(recovered_address) = - parsed_signature.recover_address_from_msg(challenge_message) - else { - // Handle address recovery error - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - }; - - // Check if the recovered address is in allowed addresses - NONCE_CACHE.remove(challenge_message); - current_challenge = None; - if allowed_addresses.contains(&recovered_address) { - is_authorized = true; - P2PResponse::new(request.id, P2PMessage::AuthGranted {}) - } else { - P2PResponse::new(request.id, P2PMessage::AuthRejected {}) - } - } - P2PMessage::HardwareChallenge { challenge, .. 
} if is_authorized => { - info!("Received hardware challenge"); - let challenge_response = calc_matrix(&challenge); - P2PResponse::new( - request.id, - P2PMessage::HardwareChallengeResponse { - response: challenge_response, - timestamp: SystemTime::now(), - }, - ) - } - P2PMessage::Invite(invite) if is_authorized => { - if let Some(context) = &context { - let (status, error) = Self::handle_invite(invite, context).await; - P2PResponse::new(request.id, P2PMessage::InviteResponse { status, error }) - } else { - P2PResponse::new( - request.id, - P2PMessage::InviteResponse { - status: "error".to_string(), - error: Some("No context".to_string()), - }, - ) - } - } - P2PMessage::GetTaskLogs if is_authorized => { - if let Some(context) = &context { - let logs = context.docker_service.get_logs().await; - let response_logs = logs - .map(|log_string| vec![log_string]) - .map_err(|e| e.to_string()); - P2PResponse::new( - request.id, - P2PMessage::GetTaskLogsResponse { - logs: response_logs, - }, - ) - } else { - P2PResponse::new( - request.id, - P2PMessage::GetTaskLogsResponse { logs: Ok(vec![]) }, - ) - } - } - P2PMessage::RestartTask if is_authorized => { - if let Some(context) = &context { - let result = context.docker_service.restart_task().await; - let response_result = result.map_err(|e| e.to_string()); - P2PResponse::new( - request.id, - P2PMessage::RestartTaskResponse { - result: response_result, - }, - ) - } else { - P2PResponse::new( - request.id, - P2PMessage::RestartTaskResponse { result: Ok(()) }, - ) - } - } - _ => { - warn!("Unexpected message type"); - continue; - } - }; - - // Send response - Self::write_response(&mut send, response).await?; - } - - Ok(()) - } - - async fn handle_invite( - invite: InviteRequest, - context: &P2PContext, - ) -> (String, Option) { - if context.system_state.is_running().await { - return ( - "error".to_string(), - Some("Heartbeat is currently running and in a compute pool".to_string()), - ); - } - if let Some(pool_id) = 
context.system_state.compute_pool_id.clone() { - if invite.pool_id.to_string() != pool_id { - return ("error".to_string(), Some("Invalid pool ID".to_string())); - } - } - - let invite_bytes = match hex::decode(&invite.invite) { - Ok(bytes) => bytes, - Err(err) => { - error!("Failed to decode invite hex string: {err:?}"); - return ( - "error".to_string(), - Some("Invalid invite format".to_string()), - ); - } - }; - - if invite_bytes.len() < 65 { - return ( - "error".to_string(), - Some("Invite data is too short".to_string()), - ); - } - - let contracts = &context.contracts; - let wallet = &context.node_wallet; - let pool_id = U256::from(invite.pool_id); - - let bytes_array: [u8; 65] = match invite_bytes[..65].try_into() { - Ok(array) => array, - Err(_) => { - error!("Failed to convert invite bytes to fixed-size array"); - return ( - "error".to_string(), - Some("Invalid invite signature format".to_string()), - ); - } - }; - - let provider_address = context.provider_wallet.wallet.default_signer().address(); - - let pool_info = match contracts.compute_pool.get_pool_info(pool_id).await { - Ok(info) => info, - Err(err) => { - error!("Failed to get pool info: {err:?}"); - return ( - "error".to_string(), - Some("Failed to get pool information".to_string()), - ); - } - }; - - if let PoolStatus::PENDING = pool_info.status { - Console::user_error("Pool is pending - Invite is invalid"); - return ( - "error".to_string(), - Some("Pool is pending - Invite is invalid".to_string()), - ); - } - - let node_address = vec![wallet.wallet.default_signer().address()]; - let signatures = vec![FixedBytes::from(&bytes_array)]; - let nonces = vec![invite.nonce]; - let expirations = vec![invite.expiration]; - let call = match contracts.compute_pool.build_join_compute_pool_call( - pool_id, - provider_address, - node_address, - nonces, - expirations, - signatures, - ) { - Ok(call) => call, - Err(err) => { - error!("Failed to build join compute pool call: {err:?}"); - return ( - 
"error".to_string(), - Some("Failed to build join compute pool call".to_string()), - ); - } - }; - let provider = &context.provider_wallet.provider; - match retry_call(call, 3, provider.clone(), None).await { - Ok(result) => { - Console::section("WORKER JOINED COMPUTE POOL"); - Console::success(&format!( - "Successfully registered on chain with tx: {result}" - )); - Console::info( - "Status", - "Worker is now part of the compute pool and ready to receive tasks", - ); - } - Err(err) => { - error!("Failed to join compute pool: {err:?}"); - return ( - "error".to_string(), - Some(format!("Failed to join compute pool: {err}")), - ); - } - } - let endpoint = if let Some(url) = &invite.master_url { - format!("{url}/heartbeat") - } else { - match (&invite.master_ip, &invite.master_port) { - (Some(ip), Some(port)) => format!("http://{ip}:{port}/heartbeat"), - _ => { - error!("Missing master IP or port in invite request"); - return ( - "error".to_string(), - Some("Missing master IP or port".to_string()), - ); - } - } - }; - - if let Err(err) = context.heartbeat_service.start(endpoint).await { - error!("Failed to start heartbeat service: {err:?}"); - return ( - "error".to_string(), - Some("Failed to start heartbeat service".to_string()), - ); - } - - ("ok".to_string(), None) - } -} - -#[cfg(test)] -mod tests { - use rand_v8::Rng; - use serial_test::serial; - use shared::p2p::P2PClient; - use url::Url; - - use super::*; - - async fn setup_test_service( - include_addresses: bool, - ) -> (P2PService, P2PClient, Address, Address) { - let validator_wallet = shared::web3::wallet::Wallet::new( - "0000000000000000000000000000000000000000000000000000000000000001", - Url::parse("https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161").unwrap(), - ) - .unwrap(); - let worker_wallet = shared::web3::wallet::Wallet::new( - "0000000000000000000000000000000000000000000000000000000000000002", - Url::parse("https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161").unwrap(), - ) - 
.unwrap(); - let validator_wallet_address = validator_wallet.wallet.default_signer().address(); - let worker_wallet_address = worker_wallet.wallet.default_signer().address(); - let service = P2PService::new( - None, - CancellationToken::new(), - None, - worker_wallet, - if include_addresses { - vec![validator_wallet_address] - } else { - vec![] - }, - ) - .await - .unwrap(); - let client = P2PClient::new(validator_wallet.clone()).await.unwrap(); - ( - service, - client, - validator_wallet_address, - worker_wallet_address, - ) - } - - #[tokio::test] - #[serial] - async fn test_ping() { - let (service, client, _, worker_wallet_address) = setup_test_service(true).await; - let node_id = service.node_id().to_string(); - let addresses = service.listening_addresses().to_vec(); - let random_nonce = rand_v8::thread_rng().gen::(); - - tokio::spawn(async move { - service.start().unwrap(); - }); - - let ping = P2PMessage::Ping { - nonce: random_nonce, - timestamp: SystemTime::now(), - }; - - let response = client - .send_request(&node_id, &addresses, worker_wallet_address, ping, 20) - .await - .unwrap(); - - let response_nonce = match response { - P2PMessage::Pong { nonce, .. 
} => nonce, - _ => panic!("Expected Pong message"), - }; - assert_eq!(response_nonce, random_nonce); - } - #[tokio::test] - #[serial] - async fn test_auth_error() { - let (service, client, _, worker_wallet_address) = setup_test_service(false).await; - let node_id = service.node_id().to_string(); - let addresses = service.listening_addresses().to_vec(); - - tokio::spawn(async move { - service.start().unwrap(); - }); - - let ping = P2PMessage::Ping { - nonce: rand_v8::thread_rng().gen::(), - timestamp: SystemTime::now(), - }; - - // Since we set include_addresses to false, the client's wallet address - // is not in the allowed_addresses list, so we expect auth to be rejected - let result = client - .send_request(&node_id, &addresses, worker_wallet_address, ping, 20) - .await; - - assert!( - result.is_err(), - "Expected auth to be rejected but request succeeded" - ); - } -} diff --git a/crates/worker/src/state/system_state.rs b/crates/worker/src/state/system_state.rs index fd8f0a3a..39955de8 100644 --- a/crates/worker/src/state/system_state.rs +++ b/crates/worker/src/state/system_state.rs @@ -1,8 +1,8 @@ +use anyhow::bail; use anyhow::Result; use directories::ProjectDirs; use log::debug; use log::error; -use log::warn; use serde::{Deserialize, Serialize}; use std::fs; use std::path::Path; @@ -10,9 +10,6 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use crate::utils::p2p::generate_iroh_node_id_from_seed; -use crate::utils::p2p::generate_random_seed; - const STATE_FILENAME: &str = "heartbeat_state.toml"; fn get_default_state_dir() -> Option { @@ -23,8 +20,29 @@ fn get_default_state_dir() -> Option { #[derive(Debug, Clone, Serialize, Deserialize)] struct PersistedSystemState { endpoint: Option, - p2p_seed: Option, - worker_p2p_seed: Option, + #[serde( + serialize_with = "serialize_keypair", + deserialize_with = "deserialize_keypair" + )] + p2p_keypair: p2p::Keypair, +} + +fn serialize_keypair(keypair: &p2p::Keypair, serializer: S) -> Result 
+where + S: serde::Serializer, +{ + let serialized = keypair + .to_protobuf_encoding() + .map_err(serde::ser::Error::custom)?; + serializer.serialize_bytes(&serialized) +} + +fn deserialize_keypair<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let serialized: Vec = Deserialize::deserialize(deserializer)?; + p2p::Keypair::from_protobuf_encoding(&serialized).map_err(serde::de::Error::custom) } #[derive(Debug, Clone)] @@ -34,28 +52,26 @@ pub(crate) struct SystemState { endpoint: Arc>>, state_dir_overwrite: Option, disable_state_storing: bool, - pub compute_pool_id: Option, - - pub worker_p2p_seed: Option, - pub p2p_id: Option, - pub p2p_seed: Option, + compute_pool_id: u32, + p2p_keypair: p2p::Keypair, } impl SystemState { pub(crate) fn new( state_dir: Option, disable_state_storing: bool, - compute_pool_id: Option, - ) -> Self { + compute_pool_id: u32, + ) -> Result { let default_state_dir = get_default_state_dir(); debug!("Default state dir: {default_state_dir:?}"); let state_path = state_dir .map(PathBuf::from) .or_else(|| default_state_dir.map(PathBuf::from)); debug!("State path: {state_path:?}"); + let mut endpoint = None; - let mut p2p_seed: Option = None; - let mut worker_p2p_seed: Option = None; + let mut p2p_keypair = None; + // Try to load state, log info if creating new file if !disable_state_storing { if let Some(path) = &state_path { @@ -67,78 +83,52 @@ impl SystemState { } else if let Ok(Some(loaded_state)) = SystemState::load_state(path) { debug!("Loaded previous state from {state_file:?}"); endpoint = loaded_state.endpoint; - p2p_seed = loaded_state.p2p_seed; - worker_p2p_seed = loaded_state.worker_p2p_seed; + p2p_keypair = Some(loaded_state.p2p_keypair); } else { - debug!("Failed to load state from {state_file:?}"); + bail!("failed to load state from {state_file:?}"); } } } - if p2p_seed.is_none() { - let seed = generate_random_seed(); - p2p_seed = Some(seed); - } - // Generate p2p_id from seed if available - - let 
p2p_id: Option = - p2p_seed.and_then(|seed| match generate_iroh_node_id_from_seed(seed) { - Ok(id) => Some(id), - Err(_) => { - warn!("Failed to generate p2p_id from seed"); - None - } - }); - if worker_p2p_seed.is_none() { - let seed = generate_random_seed(); - worker_p2p_seed = Some(seed); + if p2p_keypair.is_none() { + p2p_keypair = Some(p2p::Keypair::generate_ed25519()); } - Self { + Ok(Self { last_heartbeat: Arc::new(RwLock::new(None)), is_running: Arc::new(RwLock::new(false)), endpoint: Arc::new(RwLock::new(endpoint)), state_dir_overwrite: state_path.clone(), disable_state_storing, compute_pool_id, - p2p_seed, - p2p_id, - worker_p2p_seed, - } + p2p_keypair: p2p_keypair.expect("p2p keypair must be Some at this point"), + }) } + fn save_state(&self, heartbeat_endpoint: Option) -> Result<()> { if !self.disable_state_storing { debug!("Saving state"); if let Some(state_dir) = &self.state_dir_overwrite { - // Get values without block_on - debug!("Saving p2p_seed: {:?}", self.p2p_seed); - - // Ensure p2p_seed is valid before creating state - if let Some(seed) = self.p2p_seed { - let state = PersistedSystemState { - endpoint: heartbeat_endpoint, - p2p_seed: Some(seed), - worker_p2p_seed: self.worker_p2p_seed, - }; - - debug!("state: {state:?}"); - - fs::create_dir_all(state_dir)?; - let state_path = state_dir.join(STATE_FILENAME); - - // Use JSON serialization instead of TOML - match serde_json::to_string_pretty(&state) { - Ok(json_string) => { - fs::write(&state_path, json_string)?; - debug!("Saved state to {state_path:?}"); - } - Err(e) => { - error!("Failed to serialize state: {e}"); - return Err(anyhow::anyhow!("Failed to serialize state: {}", e)); - } + let state = PersistedSystemState { + endpoint: heartbeat_endpoint, + p2p_keypair: self.p2p_keypair.clone(), + }; + + debug!("state: {state:?}"); + + fs::create_dir_all(state_dir)?; + let state_path = state_dir.join(STATE_FILENAME); + + // Use JSON serialization instead of TOML + match 
serde_json::to_string_pretty(&state) { + Ok(json_string) => { + fs::write(&state_path, json_string)?; + debug!("Saved state to {state_path:?}"); + } + Err(e) => { + error!("Failed to serialize state: {e}"); + return Err(anyhow::anyhow!("Failed to serialize state: {}", e)); } - } else { - warn!("Cannot save state: p2p_seed is None"); } } } @@ -152,20 +142,23 @@ impl SystemState { match serde_json::from_str(&contents) { Ok(state) => return Ok(Some(state)), Err(e) => { - debug!("Error parsing state file: {e}"); - return Ok(None); + bail!("failed to parse state file: {e}"); } } } Ok(None) } - pub(crate) fn get_p2p_seed(&self) -> Option { - self.p2p_seed + pub(crate) fn get_compute_pool_id(&self) -> u32 { + self.compute_pool_id } - pub(crate) fn get_p2p_id(&self) -> Option { - self.p2p_id.clone() + pub(crate) fn get_p2p_keypair(&self) -> &p2p::Keypair { + &self.p2p_keypair + } + + pub(crate) fn get_p2p_id(&self) -> p2p::PeerId { + self.p2p_keypair.public().to_peer_id() } pub(crate) async fn update_last_heartbeat(&self) { @@ -238,9 +231,9 @@ mod tests { let state = SystemState::new( Some(temp_dir.path().to_string_lossy().to_string()), false, - None, - ); - assert!(state.p2p_id.is_some()); + 0, + ) + .unwrap(); let _ = state .set_running(true, Some("http://localhost:8080/heartbeat".to_string())) .await; @@ -263,30 +256,33 @@ mod tests { let state_file = temp_dir.path().join(STATE_FILENAME); fs::write(&state_file, "invalid_toml_content").expect("Failed to write to state file"); - let state = SystemState::new( + assert!(SystemState::new( Some(temp_dir.path().to_string_lossy().to_string()), false, - None, - ); - assert!(!(state.is_running().await)); - assert_eq!(state.get_heartbeat_endpoint().await, None); + 0, + ) + .is_err()); } #[tokio::test] async fn test_load_state() { + let keypair = p2p::Keypair::generate_ed25519(); + let state = PersistedSystemState { + endpoint: Some("http://localhost:8080/heartbeat".to_string()), + p2p_keypair: keypair, + }; + let serialized = 
serde_json::to_string_pretty(&state).unwrap(); + let temp_dir = setup_test_dir(); let state_file = temp_dir.path().join(STATE_FILENAME); - fs::write( - &state_file, - r#"{"endpoint": "http://localhost:8080/heartbeat"}"#, - ) - .expect("Failed to write to state file"); + fs::write(&state_file, serialized).unwrap(); let state = SystemState::new( Some(temp_dir.path().to_string_lossy().to_string()), false, - None, - ); + 0, + ) + .unwrap(); assert_eq!( state.get_heartbeat_endpoint().await, Some("http://localhost:8080/heartbeat".to_string()) diff --git a/crates/worker/src/utils/logging.rs b/crates/worker/src/utils/logging.rs index 18c8de4b..312d565c 100644 --- a/crates/worker/src/utils/logging.rs +++ b/crates/worker/src/utils/logging.rs @@ -75,10 +75,6 @@ pub fn setup_logging(cli: Option<&Cli>) -> Result<(), Box u64 { - rand_v8::thread_rng().gen() -} - -// Generate an Iroh node ID from a seed -pub(crate) fn generate_iroh_node_id_from_seed(seed: u64) -> Result> { - // Create a deterministic RNG from the seed - let mut rng = StdRng::seed_from_u64(seed); - - // Generate the secret key using Iroh's method - // This matches exactly how it's done in your Node implementation - let secret_key = SecretKey::generate(&mut rng); - - // Get the node ID (public key) as a string - let node_id = secret_key.public().to_string(); - - Ok(node_id) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generate_random_seed() { - let seed1 = generate_random_seed(); - let seed2 = generate_random_seed(); - - assert_ne!(seed1, seed2); - } - - #[test] - fn test_known_generation() { - let seed: u32 = 848364385; - let result = generate_iroh_node_id_from_seed(seed as u64).unwrap(); - assert_eq!( - result, - "6ba970180efbd83909282ac741085431f54aa516e1783852978bd529a400d0e9" - ); - assert_eq!(result.len(), 64); - } - - #[test] - fn test_deterministic_generation() { - // Same seed should generate same node_id - let seed = generate_random_seed(); - println!("seed: {}", seed); - let 
result1 = generate_iroh_node_id_from_seed(seed).unwrap(); - let result2 = generate_iroh_node_id_from_seed(seed).unwrap(); - println!("result1: {}", result1); - - assert_eq!(result1, result2); - } -}