diff --git a/.gitignore b/.gitignore index b2d8069..ee3dd5f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ /target/ /.vscode/ -best-agent.json -fitness-plot.svg \ No newline at end of file +fitness-plot.svg +network.json \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 0a9a53b..1d675b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,21 +10,21 @@ checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -41,21 +41,55 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.19" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core", + "quote", + "syn", +] [[package]] name = "either" -version = "1.9.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "genetic-rs" -version = "0.5.4" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a68bb62a836f6ea3261d77cfec4012316e206f53e7d0eab519f5f3630e86001f" +checksum = "ba4095966caf1ba9e16f0b3a6b3c58468ce21d3fd4beccf207f141fc325e0802" dependencies = [ "genetic-rs-common", "genetic-rs-macros", @@ -63,21 +97,22 @@ dependencies = [ [[package]] name = "genetic-rs-common" -version = "0.5.4" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be7aaffd4e4dc82d11819d40794f089c37d02595a401f229ed2877d1a4c401d" +checksum = "49d7c66e226c1c506c3948d1bb799b59141a8b388d7188c2091ef1c69a2aaeba" dependencies = [ + "itertools", 
"rand", "rayon", - "replace_with", ] [[package]] name = "genetic-rs-macros" -version = "0.5.4" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e73b1f36ea3e799232e1a3141a2765fa6ee9ed7bb3fed96ccfb3bf272d1832e" +checksum = "a5a20679fa28498b37ba820d1fdf1c7d948b5fd47333608a3e336dd63a7c12c5" dependencies = [ + "darling", "genetic-rs-common", "proc-macro2", "quote", @@ -86,13 +121,29 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.12" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", - "wasi", + "r-efi", + "wasip2", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", ] [[package]] @@ -109,9 +160,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.169" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "memchr" @@ -132,48 +183,57 @@ dependencies = [ "serde", "serde-big-array", "serde_json", + "serde_path_to_error", ] [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro2" -version = "1.0.91" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.35" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" -version = "0.8.5" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "libc", "rand_chacha", "rand_core", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = 
"d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", "rand_core", @@ -181,18 +241,18 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ "getrandom", ] [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -200,9 +260,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -210,22 +270,17 @@ dependencies = [ [[package]] name = "replace_with" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690" - -[[package]] -name = "ryu" -version = "1.0.19" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "51743d3e274e2b18df81c4dc6caf8a5b8e15dbe799e0dca05c7617380094e884" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] @@ -238,11 +293,20 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -251,21 +315,39 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", - "ryu", "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "syn" -version = "2.0.89" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -274,12 +356,47 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + +[[package]] +name = "zerocopy" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +name = "zmij" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/Cargo.toml b/Cargo.toml index 4b26e0f..4c342f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,20 +17,35 @@ rustdoc-args = ["--cfg", "docsrs"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[example]] +name = "basic" +path = "examples/basic.rs" +required-features = ["genetic-rs/derive"] + +[[example]] +name = "extra_genes" +path = "examples/extra_genes.rs" +required-features = ["genetic-rs/derive"] + +[[example]] +name = "serde" +path = "examples/serde.rs" +required-features = ["serde"] + [features] default = [] serde = ["dep:serde", "dep:serde-big-array"] - [dependencies] atomic_float = "1.1.0" -bitflags = "2.8.0" -genetic-rs = { version = "0.5.4", features = ["rayon", "derive"] } +bitflags = "2.10.0" +genetic-rs = { version = "1.1.0", features = ["rayon"] } lazy_static = "1.5.0" -rayon = "1.10.0" -replace_with = "0.1.7" -serde = { version = "1.0.217", features = ["derive"], optional = true } +rayon = "1.11.0" +replace_with = "0.1.8" +serde = { version = "1.0.228", features = ["derive"], optional = true } serde-big-array = { version = "0.5.1", optional = true } [dev-dependencies] -serde_json = "1.0.138" \ No newline at end of file +serde_json = "1.0.149" +serde_path_to_error = "0.1.20" diff --git a/README.md b/README.md index 4e9828b..52d2741 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,86 @@ Implementation of the NEAT algorithm using `genetic-rs`. *Do you like this crate and want to support it? 
If so, leave a ⭐* # How To Use -TODO +The `NeuralNetwork` struct is the main type exported by this crate. The const generics `I` and `O` are the numbers of input and output neurons, respectively. It implements `GenerateRandom`, `RandomlyMutable`, `Mitosis`, and `Crossover`, with plenty of room for customization. This means that you can use it standalone as your organism's entire genome: +```rust +use neat::*; + +fn fitness(net: &NeuralNetwork<5, 6>) -> f32 { + // ideally you'd test multiple times for consistency, + // but this is just a simple example. + // it's also generally good to normalize your inputs between -1..1, + // but NEAT is usually flexible enough to still work anyway + let inputs = [1.0, 2.0, 3.0, 4.0, 5.0]; + let outputs = net.predict(inputs); + + // simple fitness: sum of outputs + // you should replace this with a real fitness test + outputs.iter().sum() +} + +fn main() { + let mut rng = rand::rng(); + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 100), + FitnessEliminator::new_with_default(fitness), + CrossoverRepopulator::new(0.25, ReproductionSettings::default()), + ); + + sim.perform_generations(100); +} +``` + +Or just a part of a more complex genome: +```rust,ignore +use neat::*; + +#[derive(Clone, Debug)] +struct PhysicalStats { + strength: f32, + speed: f32, + // ... +} + +// ... implement `RandomlyMutable`, `GenerateRandom`, `Crossover`, etc. + +#[derive(Clone, Debug, GenerateRandom, RandomlyMutable, Mitosis, Crossover)] +#[randmut(create_context = MyGenomeCtx)] +#[crossover(with_context = MyGenomeCtx)] +struct MyGenome { + brain: NeuralNetwork<4, 2>, + stats: PhysicalStats, +} + +impl Default for MyGenomeCtx { + fn default() -> Self { + Self { + brain: ReproductionSettings::default(), + stats: PhysicalStats::default(), + } + } +} + +fn fitness(genome: &MyGenome) -> f32 { + let inputs = [1.0, 2.0, 3.0, 4.0]; + let outputs = genome.brain.predict(inputs); + // fitness uses both brain output and stats + outputs.iter().sum::<f32>() + genome.stats.strength + genome.stats.speed +} + +// main is the exact same as before +fn main() { + let mut rng = rand::rng(); + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 100), + FitnessEliminator::new_with_default(fitness), + CrossoverRepopulator::new(0.25, MyGenomeCtx::default()), + ); + + sim.perform_generations(100); +} +``` + +If you want more in-depth examples, look at the [examples](https://github.com/HyperCodec/neat/tree/main/examples). You can also check out the [genetic-rs docs](https://docs.rs/genetic_rs) to see what other options you have to customize your genetic simulation. ### License This crate falls under the `MIT` license diff --git a/examples/basic.rs b/examples/basic.rs index 85f58cb..74dbe91 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,3 +1,74 @@ +use neat::*; + +// approximate the to_degrees function, which should be pretty +// hard for a traditional network to learn since its inputs and outputs are nowhere near a -1..1 mapping. +fn fitness(net: &NeuralNetwork<1, 1>) -> f32 { + let mut rng = rand::rng(); + let mut total_fitness = 0.0; + + // it's good practice to test on multiple inputs to get a more accurate fitness score + for _ in 0..100 { + let input = rng.random_range(-10.0..10.0); + let output = net.predict([input])[0]; + let expected_output = input.to_degrees(); + + // basically just using negative error as fitness. + // percentage error doesn't work as well here since + // expected_output can be either very small or very large in magnitude.
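+ // (for example, input 0.01 expects ~0.573 while input 10.0 expects ~573.0, so percentage error would explode near zero)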
+ total_fitness -= (output - expected_output).abs(); + } + + total_fitness +} + fn main() { - todo!("use NeuralNetwork as the entire DNA"); + let mut rng = rand::rng(); + + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 250), + FitnessEliminator::new_with_default(fitness), + CrossoverRepopulator::new(0.25, ReproductionSettings::default()), + ); + + for i in 0..=150 { + sim.next_generation(); + + // sample a genome to print its fitness. + // this value should approach 0 as the generations go on, since the fitness is negative error. + // with the way CrossoverRepopulator (and all builtin repopulators) works internally, the parent genomes + // (i.e. prev generation champs) are more likely to be at the start of the genomes vector. + let sample = &sim.genomes[0]; + let fit = fitness(sample); + println!("Gen {i} sample fitness: {fit}"); + } + println!("Training complete, now you can test the network!"); + + let net = &sim.genomes[0]; + println!("Network in use: {:#?}", net); + + loop { + let mut input_text = String::new(); + println!("Enter a number to convert to degrees (or 'exit' to quit): "); + std::io::stdin().read_line(&mut input_text).unwrap(); + let input_text = input_text.trim(); + if input_text.eq_ignore_ascii_case("exit") { + break; + } + let input: f32 = match input_text.parse() { + Ok(num) => num, + Err(_) => { + println!("Invalid input, please enter a valid number."); + continue; + } + }; + + let output = net.predict([input])[0]; + let expected_output = input.to_degrees(); + println!( + "Network output: {}, Expected output: {}, Error: {}", + output, + expected_output, + (output - expected_output).abs() + ); + } } diff --git a/examples/extra_dna.rs b/examples/extra_dna.rs deleted file mode 100644 index 038709f..0000000 --- a/examples/extra_dna.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - todo!("use AgentDNA with additional params") -} diff --git a/examples/extra_genes.rs b/examples/extra_genes.rs new file mode 100644 index 0000000..d4e6f01 --- /dev/null +++ b/examples/extra_genes.rs @@ -0,0 +1,372 @@ +use neat::*; +use std::f32::consts::PI; + +// ========================================================================== +// SIMULATION CONSTANTS - Adjust these to experiment with different dynamics +// ========================================================================== + +// World/Environment Settings +const WORLD_WIDTH: f32 = 800.0; +const WORLD_HEIGHT: f32 = 600.0; +const INITIAL_FOOD_COUNT: usize = 20; +const FOOD_RESPAWN_THRESHOLD: usize = 10; +const FOOD_DETECTION_DISTANCE: f32 = 10.0; + +// Energy/Food Settings +const BASE_FOOD_ENERGY: f32 = 20.0; // Energy from each food item +const STRENGTH_ENERGY_MULTIPLIER: f32 = 10.0; // Extra energy per strength stat +const MOVEMENT_ENERGY_COST: f32 = 0.2; // Cost per unit of movement +const IDLE_ENERGY_COST: f32 = 0.1; // Cost per timestep just existing + +// Fitness Settings +const FITNESS_PER_FOOD: f32 = 100.0; // Points per food eaten + +// Physical Stats - Min/Max Bounds +const SPEED_MIN: f32 = 0.5; +const SPEED_MAX: f32 = 6.0; +const STRENGTH_MIN: f32 = 0.2; +const STRENGTH_MAX: f32 = 4.0; +const SENSE_RANGE_MIN: f32 = 30.0; +const SENSE_RANGE_MAX: f32 = 250.0; +const ENERGY_CAPACITY_MIN: f32 = 50.0; +const ENERGY_CAPACITY_MAX: f32 = 400.0; + +// Physical Stats - Initial Generation Range +const SPEED_INIT_MIN: f32 = 1.0; +const SPEED_INIT_MAX: f32 = 5.0; +const STRENGTH_INIT_MIN: f32 = 0.5; +const STRENGTH_INIT_MAX: f32 = 3.0; +const SENSE_RANGE_INIT_MIN: f32 = 50.0; +const SENSE_RANGE_INIT_MAX: f32 = 200.0; 
+const ENERGY_CAPACITY_INIT_MIN: f32 = 100.0; +const ENERGY_CAPACITY_INIT_MAX: f32 = 300.0; + +// Mutation Settings +const SPEED_MUTATION_PROB: f32 = 0.3; +const SPEED_MUTATION_RANGE: f32 = 0.5; +const STRENGTH_MUTATION_PROB: f32 = 0.2; +const STRENGTH_MUTATION_RANGE: f32 = 0.3; +const SENSE_MUTATION_PROB: f32 = 0.2; +const SENSE_MUTATION_RANGE: f32 = 20.0; +const CAPACITY_MUTATION_PROB: f32 = 0.2; +const CAPACITY_MUTATION_RANGE: f32 = 30.0; + +// Genetic Algorithm Settings +const POPULATION_SIZE: usize = 150; +const HIGHEST_GENERATION: usize = 250; +const SIMULATION_TIMESTEPS: usize = 500; +const MUTATION_RATE: f32 = 0.3; + +/// Mutation settings for physical stats +#[derive(Clone, Debug)] +struct PhysicalStatsMutationSettings { + speed_prob: f32, + speed_range: f32, + strength_prob: f32, + strength_range: f32, + sense_prob: f32, + sense_range: f32, + capacity_prob: f32, + capacity_range: f32, +} + +impl Default for PhysicalStatsMutationSettings { + fn default() -> Self { + Self { + speed_prob: SPEED_MUTATION_PROB, + speed_range: SPEED_MUTATION_RANGE, + strength_prob: STRENGTH_MUTATION_PROB, + strength_range: STRENGTH_MUTATION_RANGE, + sense_prob: SENSE_MUTATION_PROB, + sense_range: SENSE_MUTATION_RANGE, + capacity_prob: CAPACITY_MUTATION_PROB, + capacity_range: CAPACITY_MUTATION_RANGE, + } + } +} + +/// Physical traits/stats for an organism +#[derive(Clone, Debug, PartialEq)] +struct PhysicalStats { + /// Speed multiplier (faster = longer strides but more energy cost) + speed: f32, + /// Strength stat (affects energy from food) + strength: f32, + /// Sense range (how far it can detect food) + sense_range: f32, + /// Energy capacity (larger = can go longer without food) + energy_capacity: f32, +} + +impl PhysicalStats { + fn clamp(&mut self) { + self.speed = self.speed.clamp(SPEED_MIN, SPEED_MAX); + self.strength = self.strength.clamp(STRENGTH_MIN, STRENGTH_MAX); + self.sense_range = self.sense_range.clamp(SENSE_RANGE_MIN, SENSE_RANGE_MAX); + self.energy_capacity = self + .energy_capacity + .clamp(ENERGY_CAPACITY_MIN, ENERGY_CAPACITY_MAX); + } +} + +impl GenerateRandom for PhysicalStats { + fn gen_random(rng: &mut impl rand::Rng) -> Self { + let mut stats = PhysicalStats { + speed: rng.random_range(SPEED_INIT_MIN..SPEED_INIT_MAX), + strength: rng.random_range(STRENGTH_INIT_MIN..STRENGTH_INIT_MAX), + sense_range: rng.random_range(SENSE_RANGE_INIT_MIN..SENSE_RANGE_INIT_MAX), + energy_capacity: rng.random_range(ENERGY_CAPACITY_INIT_MIN..ENERGY_CAPACITY_INIT_MAX), + }; + stats.clamp(); + stats + } +} + +impl RandomlyMutable for PhysicalStats { + type Context = PhysicalStatsMutationSettings; + + fn mutate(&mut self, context: &Self::Context, _severity: f32, rng: &mut impl rand::Rng) { + if rng.random::<f32>() < context.speed_prob { + self.speed += rng.random_range(-context.speed_range..context.speed_range); + } + if rng.random::<f32>() < context.strength_prob { + self.strength += rng.random_range(-context.strength_range..context.strength_range); + } + if rng.random::<f32>() < context.sense_prob { + self.sense_range += rng.random_range(-context.sense_range..context.sense_range); + } + if rng.random::<f32>() < context.capacity_prob { + self.energy_capacity += + rng.random_range(-context.capacity_range..context.capacity_range); + } + self.clamp(); + } +} + +impl Crossover for PhysicalStats { + type Context = PhysicalStatsMutationSettings; + + fn crossover( + &self, + other: &Self, + context: &Self::Context, + _severity: f32, + rng: &mut impl rand::Rng, + ) -> Self { + let mut child = PhysicalStats { + speed:
(self.speed + other.speed) / 2.0 + + rng.random_range(-context.speed_range..context.speed_range), + strength: (self.strength + other.strength) / 2.0 + + rng.random_range(-context.strength_range..context.strength_range), + sense_range: (self.sense_range + other.sense_range) / 2.0 + + rng.random_range(-context.sense_range..context.sense_range), + energy_capacity: (self.energy_capacity + other.energy_capacity) / 2.0 + + rng.random_range(-context.capacity_range..context.capacity_range), + }; + child.clamp(); + child + } +} + +/// A complete organism genome containing both neural network and physical traits +#[derive(Clone, Debug, PartialEq, GenerateRandom, RandomlyMutable, Crossover)] +#[randmut(create_context = OrganismCtx)] +#[crossover(with_context = OrganismCtx)] +struct OrganismGenome { + brain: NeuralNetwork<8, 2>, + stats: PhysicalStats, +} + +/// Running instance of an organism with current position and energy +struct OrganismInstance { + genome: OrganismGenome, + x: f32, + y: f32, + angle: f32, + energy: f32, + lifetime: usize, + food_eaten: usize, +} + +impl OrganismInstance { + fn new(genome: OrganismGenome) -> Self { + let energy = genome.stats.energy_capacity; + Self { + genome, + x: rand::random::<f32>() * WORLD_WIDTH, + y: rand::random::<f32>() * WORLD_HEIGHT, + angle: rand::random::<f32>() * 2.0 * PI, + energy, + lifetime: 0, + food_eaten: 0, + } + } + + /// Simulate one timestep: sense food, decide movement, consume energy, age + fn step(&mut self, food_sources: &[(f32, f32)]) { + self.lifetime += 1; + + // find nearest food + let mut nearest_food_dist = f32::INFINITY; + let mut nearest_food_angle = 0.0; + let mut nearest_food_x_diff = 0.0; + let mut nearest_food_y_diff = 0.0; + + for &(fx, fy) in food_sources { + let dx = fx - self.x; + let dy = fy - self.y; + let dist = (dx * dx + dy * dy).sqrt(); + + if dist < self.genome.stats.sense_range && dist < nearest_food_dist { + nearest_food_dist = dist; + nearest_food_angle = (dy.atan2(dx) - self.angle).sin(); + nearest_food_x_diff = (dx / 100.0).clamp(-1.0, 1.0); + nearest_food_y_diff = (dy / 100.0).clamp(-1.0, 1.0); + } + } + + let sense_food = if nearest_food_dist < self.genome.stats.sense_range { + 1.0 + } else { + 0.0 + }; + + // Create inputs for neural network: + // 0: current energy level (0-1) + // 1: food detected (0 or 1) + // 2: nearest food angle (normalized) + // 3: nearest food x diff + // 4: nearest food y diff + // 5: speed stat (normalized) + // 6: energy capacity (normalized) + // 7: age (slow-paced, up to 1 at age 1000) + let inputs = [ + (self.energy / self.genome.stats.energy_capacity).clamp(0.0, 1.0), + sense_food, + nearest_food_angle, + nearest_food_x_diff, + nearest_food_y_diff, + (self.genome.stats.speed / 5.0).clamp(0.0, 1.0), + (self.genome.stats.energy_capacity / 200.0).clamp(0.0, 1.0), + (self.lifetime as f32 / 1000.0).clamp(0.0, 1.0), + ]; + + // get movement outputs from neural network + let outputs = self.genome.brain.predict(inputs); + let move_forward = (outputs[0] * self.genome.stats.speed).clamp(-5.0, 5.0); + let turn = (outputs[1] * PI / 4.0).clamp(-PI / 8.0, PI / 8.0); + + // update position and angle + self.angle += turn; + self.x += move_forward * self.angle.cos(); + self.y += move_forward * self.angle.sin(); + + // wrap around world + if self.x < 0.0 { + self.x += WORLD_WIDTH; + } else if self.x >= WORLD_WIDTH { + self.x -= WORLD_WIDTH; + } + if self.y < 0.0 { + self.y += WORLD_HEIGHT; + } else if self.y >= WORLD_HEIGHT { + self.y -= WORLD_HEIGHT; + } + + // consume energy for movement + let
movement_cost = (move_forward.abs() / self.genome.stats.speed).max(0.5); + self.energy -= movement_cost * MOVEMENT_ENERGY_COST; + + // consume energy for existing + self.energy -= IDLE_ENERGY_COST; + } + + /// Check if organism lands on food and consume it + fn eat(&mut self, food_sources: &mut Vec<(f32, f32)>) { + food_sources.retain(|&(fx, fy)| { + let dx = fx - self.x; + let dy = fy - self.y; + let dist = (dx * dx + dy * dy).sqrt(); + if dist < FOOD_DETECTION_DISTANCE { + // ate food + self.energy += + BASE_FOOD_ENERGY + (self.genome.stats.strength * STRENGTH_ENERGY_MULTIPLIER); + self.energy = self.energy.min(self.genome.stats.energy_capacity); + self.food_eaten += 1; + false + } else { + true + } + }); + } + + fn is_alive(&self) -> bool { + self.energy > 0.0 + } + + fn fitness(&self) -> f32 { + let food_fitness = (self.food_eaten as f32) * FITNESS_PER_FOOD; + food_fitness + } +} + +/// Evaluate an organism's fitness by running a simulation +fn evaluate_organism(genome: &OrganismGenome) -> f32 { + let mut rng = rand::rng(); + + let mut food_sources: Vec<(f32, f32)> = (0..INITIAL_FOOD_COUNT) + .map(|_| { + ( + rng.random_range(0.0..WORLD_WIDTH), + rng.random_range(0.0..WORLD_HEIGHT), + ) + }) + .collect(); + + let mut instance = OrganismInstance::new(genome.clone()); + + for _ in 0..SIMULATION_TIMESTEPS { + if instance.is_alive() { + instance.step(&food_sources); + instance.eat(&mut food_sources); + } + + // respawn food + if food_sources.len() < FOOD_RESPAWN_THRESHOLD { + food_sources.push(( + rng.random_range(0.0..WORLD_WIDTH), + rng.random_range(0.0..WORLD_HEIGHT), + )); + } + } + + instance.fitness() +} + +fn main() { + let mut rng = rand::rng(); + + println!("Starting genetic NEAT simulation with physical traits"); + println!("Population: {} organisms", POPULATION_SIZE); + println!("Each has: Neural Network Brain + Physical Stats (Speed, Strength, Sense Range, Energy Capacity)\n"); + + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, POPULATION_SIZE), + FitnessEliminator::new_with_default(evaluate_organism), + CrossoverRepopulator::new(MUTATION_RATE, OrganismCtx::default()), + ); + + for generation in 0..=HIGHEST_GENERATION { + sim.next_generation(); + + let sample = &sim.genomes[0]; + let fitness = evaluate_organism(sample); + + println!( + "Gen {}: Sample fitness: {:.1} | Speed: {:.2}, Strength: {:.2}, Sense: {:.1}, Capacity: {:.1}", + generation, fitness, sample.stats.speed, sample.stats.strength, sample.stats.sense_range, sample.stats.energy_capacity + ); + } + + println!("\nSimulation complete!"); +} diff --git a/examples/serde.rs b/examples/serde.rs new file mode 100644 index 0000000..90b4e81 --- /dev/null +++ b/examples/serde.rs @@ -0,0 +1,36 @@ +use neat::{activation::register_activation, *}; + +const OUTPUT_PATH: &str = "network.json"; + +fn magic_activation(x: f32) -> f32 { + // just a random activation function to show that it gets serialized and deserialized correctly. + (x * 2.0).sin() +} + +fn main() { + // custom activation functions must be registered before deserialization, since the network needs to know how to deserialize them. 
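+ // without this call, deserializing a network that uses magic_activation would panic with "Activation function magic_activation not found".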
+ register_activation(activation_fn!(magic_activation)); + + let mut rng = rand::rng(); + let mut net = NeuralNetwork::<10, 10>::new(&mut rng); + + println!("Mutating network..."); + + for _ in 0..100 { + net.mutate(&MutationSettings::default(), 0.25, &mut rng); + } + + let file = + std::fs::File::create(OUTPUT_PATH).expect("Failed to create file for network output"); + serde_json::to_writer_pretty(file, &net).expect("Failed to write network to file"); + + println!("Network saved to {OUTPUT_PATH}"); + + // reopen the file for reading: after the write, the cursor sits at the end of the file, + // so even a combined read/write handle (via OpenOptions) would have to seek back to the start + let file = std::fs::File::open(OUTPUT_PATH).expect("Failed to open network file for reading"); + let net2: NeuralNetwork<10, 10> = + serde_json::from_reader(file).expect("Failed to parse network from file"); + assert_eq!(net, net2); + println!("Network successfully loaded from file and matches original!"); +} diff --git a/src/activation.rs b/src/activation.rs index af9f74e..84018b2 100644 --- a/src/activation.rs +++ b/src/activation.rs @@ -4,6 +4,7 @@ pub mod builtin; use bitflags::bitflags; use builtin::*; +use genetic_rs::prelude::rand; #[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -20,11 +21,11 @@ use crate::NeuronLocation; #[macro_export] macro_rules! activation_fn { ($F: path) => { - ActivationFn::new(std::sync::Arc::new($F), NeuronScope::default(), stringify!($F).into()) + $crate::activation::ActivationFn::new(std::sync::Arc::new($F), $crate::activation::NeuronScope::default(), stringify!($F).into()) }; ($F: path, $S: expr) => { - ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into()) + $crate::activation::ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into()) }; {$($F: path),*} => { @@ -56,13 +57,13 @@ pub fn batch_register_activation(acts: impl IntoIterator<Item = ActivationFn>) { /// A registry of the different possible activation functions. pub struct ActivationRegistry { /// The currently-registered activation functions. - pub fns: HashMap<String, ActivationFn>, + pub fns: HashMap<&'static str, ActivationFn>, } impl ActivationRegistry { /// Registers an activation function. pub fn register(&mut self, activation: ActivationFn) { - self.fns.insert(activation.name.clone(), activation); + self.fns.insert(activation.name, activation); } /// Registers multiple activation functions at once. @@ -72,7 +73,7 @@ impl ActivationRegistry { } } - /// Gets a Vec of all the activation functions registered. Unless you need an owned value, use [fns][ActivationRegistry::fns].values() instead. + /// Gets a Vec of all the activation functions registered. Use [fns][ActivationRegistry::fns] if you only need an iterator. pub fn activations(&self) -> Vec<ActivationFn> { self.fns.values().cloned().collect() } @@ -82,9 +83,35 @@ impl ActivationRegistry { let acts = self.activations(); acts.into_iter() - .filter(|a| !a.scope.contains(NeuronScope::NONE) && a.scope.contains(scope)) + .filter(|a| a.scope.contains(scope)) .collect() } + + /// Clears all existing values in the activation registry. + pub fn clear(&mut self) { + self.fns.clear(); + } + + /// Fetches a random activation fn that applies to the provided scope.
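+ /// Cycles through the registry until it finds a match, so it never returns if no registered function's scope contains `scope`, and panics if the registry is empty.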
+ pub fn random_activation_in_scope( + &self, + scope: NeuronScope, + rng: &mut impl rand::Rng, + ) -> ActivationFn { + let mut iter = self.fns.values().cycle(); + let num_iterations = rng.random_range(0..self.fns.len()); + + for _ in 0..num_iterations { + iter.next().unwrap(); + } + + let mut val = iter.next().unwrap(); + while !val.scope.contains(scope) { + val = iter.next().unwrap(); + } + + val.clone() + } } impl Default for ActivationRegistry { @@ -125,19 +152,25 @@ pub struct ActivationFn { /// The scope defining where the activation function can appear. pub scope: NeuronScope, - pub(crate) name: String, + + /// The name of the activation function, used for debugging and serialization. + pub name: &'static str, } impl ActivationFn { /// Creates a new ActivationFn object. - pub fn new(func: Arc<dyn Fn(f32) -> f32 + Send + Sync>, scope: NeuronScope, name: String) -> Self { + pub fn new( + func: Arc<dyn Fn(f32) -> f32 + Send + Sync>, + scope: NeuronScope, + name: &'static str, + ) -> Self { Self { func, name, scope } } } impl fmt::Debug for ActivationFn { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "{}", self.name) + write!(f, "{}", self.name) } } @@ -150,7 +183,7 @@ impl PartialEq for ActivationFn { #[cfg(feature = "serde")] impl Serialize for ActivationFn { fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { - serializer.serialize_str(&self.name) + serializer.serialize_str(self.name) } } @@ -164,7 +197,7 @@ impl<'a> Deserialize<'a> for ActivationFn { let reg = ACTIVATION_REGISTRY.read().unwrap(); - let f = reg.fns.get(&name); + let f = reg.fns.get(name.as_str()); if f.is_none() { panic!("Activation function {name} not found"); diff --git a/src/lib.rs b/src/lib.rs index 0de7360..572c7af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,4 @@ -//! A crate implementing NeuroEvolution of Augmenting Topologies (NEAT). -//! -//! The goal is to provide a simple-to-use, very dynamic [`NeuralNetwork`] type that -//! integrates directly into the [`genetic-rs`](https://crates.io/crates/genetic-rs) ecosystem. -//! -//! Look at the README, docs, or examples to learn how to use this crate. - +#![doc = include_str!("../README.md")] #![warn(missing_docs)] /// Contains the types surrounding activation functions. @@ -17,5 +11,31 @@ pub use neuralnet::*; pub use genetic_rs::{self, prelude::*}; +/// A trait for getting the index of the maximum element. +pub trait MaxIndex<V> { + /// Returns the index of the maximum element.
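+ /// Returns [`None`] for an empty iterator. Useful as an argmax over [`NeuralNetwork::predict`] outputs, e.g. picking the strongest output neuron.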
fn max_index(self) -> Option<usize>; +} + +impl<V: PartialOrd, I: Iterator<Item = V>> MaxIndex<V> for I { + fn max_index(self) -> Option<usize> { + // enumerate now so we don't accidentally + // skip the index of the first element + let mut iter = self.enumerate(); + + let mut max_i = 0; + let mut max_v = iter.next()?.1; + + for (i, v) in iter { + if v > max_v { + max_v = v; + max_i = i; + } + } + + Some(max_i) + } +} + #[cfg(test)] mod tests; diff --git a/src/neuralnet.rs b/src/neuralnet.rs index cce0d61..1aa174b 100644 --- a/src/neuralnet.rs +++ b/src/neuralnet.rs @@ -1,5 +1,6 @@ use std::{ - collections::HashSet, + collections::{HashMap, HashSet, VecDeque}, + ops::{Index, IndexMut}, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, @@ -7,6 +8,7 @@ }; use atomic_float::AtomicF32; +use bitflags::bitflags; use genetic_rs::prelude::*; use rand::Rng; use replace_with::replace_with_or_abort; @@ -19,33 +21,33 @@ use crate::{ use rayon::prelude::*; #[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[cfg(feature = "serde")] use serde_big_array::BigArray; -/// The mutation settings for [`NeuralNetwork`]. -/// Does not affect [`NeuralNetwork::mutate`], only [`NeuralNetwork::divide`] and [`NeuralNetwork::crossover`]. -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone, PartialEq)] -pub struct MutationSettings { - /// The chance of each mutation type to occur. - pub mutation_rate: f32, - - /// The number of times to try to mutate the network. - pub mutation_passes: usize, - - /// The maximum amount that the weights will be mutated by. - pub weight_mutation_amount: f32, -} +#[cfg(feature = "serde")] +mod outputs_serde { + use super::*; + use std::collections::HashMap; + + pub fn serialize<S>( + map: &HashMap<NeuronLocation, f32>, + serializer: S, + ) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + let vec: Vec<(NeuronLocation, f32)> = map.iter().map(|(k, v)| (*k, *v)).collect(); + vec.serialize(serializer) + } -impl Default for MutationSettings { - fn default() -> Self { - Self { - mutation_rate: 0.01, - mutation_passes: 3, - weight_mutation_amount: 0.5, - } + pub fn deserialize<'de, D>(deserializer: D) -> Result<HashMap<NeuronLocation, f32>, D::Error> + where + D: Deserializer<'de>, + { + let vec: Vec<(NeuronLocation, f32)> = Vec::deserialize(deserializer)?; + Ok(vec.into_iter().collect()) } } @@ -67,20 +69,17 @@ pub struct NeuralNetwork<const I: usize, const O: usize> { /// The output layer of neurons. Their values will be returned from [`NeuralNetwork::predict`]. #[cfg_attr(feature = "serde", serde(with = "BigArray"))] pub output_layer: [Neuron; O], - - /// The mutation settings for the network. - pub mutation_settings: MutationSettings, } impl<const I: usize, const O: usize> NeuralNetwork<I, O> { // TODO option to set default output layer activations - /// Creates a new random neural network with the given settings. + /// Creates a new random neural network.
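+ /// Each input neuron starts connected to a random nonempty subset of output neurons with random weights, and output neurons default to the sigmoid activation.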
- pub fn new(mutation_settings: MutationSettings, rng: &mut impl Rng) -> Self { + pub fn new(rng: &mut impl rand::Rng) -> Self { let mut output_layer = Vec::with_capacity(O); for _ in 0..O { output_layer.push(Neuron::new_with_activation( - vec![], + HashMap::new(), activation_fn!(sigmoid), rng, )); @@ -89,20 +88,21 @@ let mut input_layer = Vec::with_capacity(I); for _ in 0..I { - let mut already_chosen = Vec::new(); - let outputs = (0..rng.gen_range(1..=O)) - .map(|_| { - let mut j = rng.gen_range(0..O); - while already_chosen.contains(&j) { - j = rng.gen_range(0..O); - } - - output_layer[j].input_count += 1; - already_chosen.push(j); - - (NeuronLocation::Output(j), rng.gen()) - }) - .collect(); + let mut already_chosen = HashSet::new(); + let num_outputs = rng.random_range(1..=O); + let mut outputs = HashMap::new(); + + for _ in 0..num_outputs { + let mut j = rng.random_range(0..O); + while already_chosen.contains(&j) { + j = rng.random_range(0..O); + } + + output_layer[j].input_count += 1; + already_chosen.insert(j); + + outputs.insert(NeuronLocation::Output(j), rng.random()); + } input_layer.push(Neuron::new_with_activation( outputs, @@ -118,7 +118,6 @@ input_layer, hidden_layers: vec![], output_layer, - mutation_settings, } } @@ -131,12 +130,17 @@ .into_par_iter() .for_each(|i| self.eval(NeuronLocation::Input(i), cache.clone())); - cache.output() - } + let mut outputs = [0.0; O]; + for (i, output) in outputs.iter_mut().enumerate().take(O) { + let n = &self.output_layer[i]; + let val = cache.get(NeuronLocation::Output(i)); + *output = n.activate(val); + } - fn eval(&self, loc: impl AsRef<NeuronLocation>, cache: Arc<NeuralNetCache<I, O>>) { - let loc = loc.as_ref(); + outputs + } + fn eval(&self, loc: NeuronLocation, cache: Arc<NeuralNetCache<I, O>>) { if !cache.claim(loc) { // some other thread is already // waiting to do this task, currently doing it, or done. @@ -150,74 +154,122 @@ rayon::yield_now(); } - let val = cache.get(loc); - let n = self.get_neuron(loc); + let n = &self[loc]; + let val = n.activate(cache.get(loc)); - n.outputs.par_iter().for_each(|(loc2, weight)| { - cache.add(loc2, n.activate(val * weight)); + n.outputs.par_iter().for_each(|(&loc2, weight)| { + cache.add(loc2, val * weight); self.eval(loc2, cache.clone()); }); } /// Get a neuron at the specified [`NeuronLocation`]. - pub fn get_neuron(&self, loc: impl AsRef<NeuronLocation>) -> &Neuron { - match loc.as_ref() { - NeuronLocation::Input(i) => &self.input_layer[*i], - NeuronLocation::Hidden(i) => &self.hidden_layers[*i], - NeuronLocation::Output(i) => &self.output_layer[*i], + pub fn get_neuron(&self, loc: NeuronLocation) -> Option<&Neuron> { + if !self.neuron_exists(loc) { + None + } else { + Some(&self[loc]) + } + } + + /// Returns whether there is a neuron at the location + pub fn neuron_exists(&self, loc: NeuronLocation) -> bool { + match loc { + NeuronLocation::Input(i) => i < I, + NeuronLocation::Hidden(i) => i < self.hidden_layers.len(), + NeuronLocation::Output(i) => i < O, } } /// Get a mutable reference to the neuron at the specified [`NeuronLocation`].
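+ /// Returns [`None`] if there is no neuron at `loc`.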
- pub fn get_neuron_mut(&mut self, loc: impl AsRef<NeuronLocation>) -> &mut Neuron { - match loc.as_ref() { - NeuronLocation::Input(i) => &mut self.input_layer[*i], - NeuronLocation::Hidden(i) => &mut self.hidden_layers[*i], - NeuronLocation::Output(i) => &mut self.output_layer[*i], + pub fn get_neuron_mut(&mut self, loc: NeuronLocation) -> Option<&mut Neuron> { + if !self.neuron_exists(loc) { + None + } else { + Some(&mut self[loc]) } } + /// Adds a new neuron to the hidden layer. Updates [`input_count`][Neuron::input_count]s automatically. + /// Removes any output connections that point to invalid neurons or would result in cyclic linkage. + /// Returns whether all output connections were valid. + /// Due to the cyclic check, this function has time complexity O(nm), where n is the number of neurons + /// and m is the number of output connections. + pub fn add_neuron(&mut self, mut n: Neuron) -> bool { + let mut valid = true; + let new_loc = NeuronLocation::Hidden(self.hidden_layers.len()); + let outputs = n.outputs.keys().cloned().collect::<Vec<_>>(); + for loc in outputs { + if !self.neuron_exists(loc) + || !self.is_connection_safe(Connection { + from: new_loc, + to: loc, + }) + { + n.outputs.remove(&loc); + valid = false; + continue; + } + + let n = &mut self[loc]; + n.input_count += 1; + } + + self.hidden_layers.push(n); + + valid + } + /// Split a [`Connection`] into two of the same weight, joined by a new [`Neuron`] in the hidden layer(s). pub fn split_connection(&mut self, connection: Connection, rng: &mut impl Rng) { - let newloc = NeuronLocation::Hidden(self.hidden_layers.len()); + let new_loc = NeuronLocation::Hidden(self.hidden_layers.len()); - let a = self.get_neuron_mut(connection.from); - let weight = unsafe { a.remove_connection(connection.to) }.unwrap(); + let a = &mut self[connection.from]; + let w = a + .outputs + .remove(&connection.to) + .expect("invalid connection.to"); - a.outputs.push((newloc, weight)); + a.outputs.insert(new_loc, w); - let n = Neuron::new(vec![(connection.to, weight)], NeuronScope::HIDDEN, rng); - self.hidden_layers.push(n); + let mut outputs = HashMap::new(); + outputs.insert(connection.to, w); + let mut new_n = Neuron::new(outputs, NeuronScope::HIDDEN, rng); + new_n.input_count = 1; + self.hidden_layers.push(new_n); } /// Adds a connection but does not check for cyclic linkages. - /// - /// # Safety - /// This is marked as unsafe because it could cause a hang/livelock when predicting due to cyclic linkage. - /// There is no actual UB or unsafe code associated with it. - pub unsafe fn add_connection_raw(&mut self, connection: Connection, weight: f32) { - let a = self.get_neuron_mut(connection.from); - a.outputs.push((connection.to, weight)); + pub fn add_connection_unchecked(&mut self, connection: Connection, weight: f32) { + let a = &mut self[connection.from]; + a.outputs.insert(connection.to, weight); - // let b = self.get_neuron_mut(connection.to); - // b.inputs.insert(connection.from); + let b = &mut self[connection.to]; + b.input_count += 1; } - /// Returns false if the connection is cyclic. + /// Returns false if the connection is cyclic or either endpoint is otherwise invalid. + /// Can be O(n) over the number of neurons in the network.
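+ /// For example, a self-connection, a connection out of an output neuron, or a connection into an input neuron is always rejected.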
pub fn is_connection_safe(&self, connection: Connection) -> bool { + if connection.from.is_output() + || connection.to.is_input() + || connection.from == connection.to + || (self.neuron_exists(connection.from) + && self[connection.from].outputs.contains_key(&connection.to)) + { + return false; + } let mut visited = HashSet::from([connection.from]); - self.dfs(&mut visited, connection.to) } - // TODO maybe parallelize fn dfs(&self, visited: &mut HashSet<NeuronLocation>, current: NeuronLocation) -> bool { if !visited.insert(current) { return false; } - let n = self.get_neuron(current); - for (loc, _) in &n.outputs { + let n = &self[current]; + for loc in n.outputs.keys() { if !self.dfs(visited, *loc) { return false; } @@ -226,32 +278,104 @@ true } - /// Safe, checked add connection method. Returns false if it aborted connecting due to cyclic linkage. + /// Safe, checked add connection method. Returns false if it aborted due to cyclic linkage. + /// Note that checking for cyclic linkage is O(n) over all neurons in the network, which + /// may be expensive for larger networks. pub fn add_connection(&mut self, connection: Connection, weight: f32) -> bool { if !self.is_connection_safe(connection) { return false; } - unsafe { - self.add_connection_raw(connection, weight); - } + self.add_connection_unchecked(connection, weight); true } + /// Attempts to add a random connection, retrying if unsafe. + /// Returns the connection if it established one before reaching max_retries. + pub fn add_random_connection( + &mut self, + max_retries: usize, + rng: &mut impl rand::Rng, + ) -> Option<Connection> { + for _ in 0..max_retries { + let a = self.random_location_in_scope(rng, !NeuronScope::OUTPUT); + let b = self.random_location_in_scope(rng, !NeuronScope::INPUT); + + let conn = Connection { from: a, to: b }; + if self.add_connection(conn, rng.random()) { + return Some(conn); + } + } + + None + } + + /// Attempts to get a random connection, retrying if the neuron it found + /// doesn't have any outbound connections. + /// Returns the connection if it found one before reaching max_retries. + pub fn get_random_connection( + &mut self, + max_retries: usize, + rng: &mut impl rand::Rng, + ) -> Option<Connection> { + for _ in 0..max_retries { + let a = self.random_location_in_scope(rng, !NeuronScope::OUTPUT); + let an = &self[a]; + if an.outputs.is_empty() { + continue; + } + + let mut iter = an + .outputs + .keys() + .skip(rng.random_range(0..an.outputs.len())); + let b = iter.next().unwrap(); + + let conn = Connection { from: a, to: *b }; + return Some(conn); + } + + None + } + + /// Attempts to remove a random connection, retrying if the neuron it found + /// doesn't have any outbound connections. Also removes hanging neurons created + /// by removing the connection. + /// + /// Returns the connection if it removed one before reaching max_retries. + pub fn remove_random_connection( + &mut self, + max_retries: usize, + rng: &mut impl rand::Rng, + ) -> Option<Connection> { + if let Some(conn) = self.get_random_connection(max_retries, rng) { + self.remove_connection(conn); + Some(conn) + } else { + None + } + } + /// Mutates a connection's weight.
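+ /// Panics (via `unwrap`) if there is no connection from `connection.from` to `connection.to`.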
- pub fn mutate_weight(&mut self, connection: Connection, rng: &mut impl Rng) { - let rate = self.mutation_settings.weight_mutation_amount; - let n = self.get_neuron_mut(connection.from); - n.mutate_weight(connection.to, rate, rng).unwrap(); + pub fn mutate_weight(&mut self, connection: Connection, amount: f32, rng: &mut impl Rng) { + let n = &mut self[connection.from]; + n.mutate_weight(connection.to, amount, rng).unwrap(); } /// Get a random valid location within the network. pub fn random_location(&self, rng: &mut impl Rng) -> NeuronLocation { - match rng.gen_range(0..3) { - 0 => NeuronLocation::Input(rng.gen_range(0..self.input_layer.len())), - 1 => NeuronLocation::Hidden(rng.gen_range(0..self.hidden_layers.len())), - 2 => NeuronLocation::Output(rng.gen_range(0..self.output_layer.len())), + if self.hidden_layers.is_empty() { + if rng.random_range(0..=1) != 0 { + return NeuronLocation::Input(rng.random_range(0..I)); + } + return NeuronLocation::Output(rng.random_range(0..O)); + } + + match rng.random_range(0..3) { + 0 => NeuronLocation::Input(rng.random_range(0..I)), + 1 => NeuronLocation::Hidden(rng.random_range(0..self.hidden_layers.len())), + 2 => NeuronLocation::Output(rng.random_range(0..O)), _ => unreachable!(), } } @@ -259,49 +383,121 @@ impl NeuralNetwork { /// Get a random valid location within a [`NeuronScope`]. pub fn random_location_in_scope( &self, - rng: &mut impl Rng, + rng: &mut impl rand::Rng, scope: NeuronScope, ) -> NeuronLocation { - let loc = self.random_location(rng); + if scope == NeuronScope::NONE { + panic!("cannot select from empty scope"); + } - // this is a lazy and slow way of donig it, TODO better version. - if !scope.contains(NeuronScope::from(loc)) { - return self.random_location_in_scope(rng, scope); + let mut layers = Vec::with_capacity(3); + if scope.contains(NeuronScope::INPUT) { + layers.push((NeuronLocation::Input(0), I)); + } + if scope.contains(NeuronScope::HIDDEN) && !self.hidden_layers.is_empty() { + layers.push((NeuronLocation::Hidden(0), self.hidden_layers.len())); + } + if scope.contains(NeuronScope::OUTPUT) { + layers.push((NeuronLocation::Output(0), O)); } + let (mut loc, size) = layers[rng.random_range(0..layers.len())]; + loc.set_inner(rng.random_range(0..size)); loc } - /// Remove a connection and any hanging neurons caused by the deletion. - /// Returns whether there was a hanging neuron. - pub fn remove_connection(&mut self, connection: Connection) -> bool { - let a = self.get_neuron_mut(connection.from); - unsafe { a.remove_connection(connection.to) }.unwrap(); + /// Remove a connection and indicate whether the destination neuron became hanging + /// (with the exception of output layer neurons). + /// Returns `true` if the destination neuron has input_count == 0 and should be removed. + /// Callers must handle the removal of the destination neuron if needed. + pub fn remove_connection_raw(&mut self, connection: Connection) -> bool { + let a = self + .get_neuron_mut(connection.from) + .expect("invalid connection.from"); + if a.outputs.remove(&connection.to).is_none() { + panic!("invalid connection.to"); + } + + let b = &mut self[connection.to]; - let b = self.get_neuron_mut(connection.to); - b.input_count -= 1; + // if the invariants held at the beginning of the call, + // this should never underflow, but some cases like remove_cycles + // may temporarily break invariants. 
+ b.input_count = b.input_count.saturating_sub(1); - if b.input_count == 0 { - self.remove_neuron(connection.to); + // signal removal + connection.to.is_hidden() && b.input_count == 0 + } + + /// Remove a connection from the network. + /// This will also deal with hanging neurons iteratively to avoid recursion that + /// can invalidate stored indices during nested deletions. + /// This method is preferable to [`remove_connection_raw`][NeuralNetwork::remove_connection_raw] for a majority of usecases, + /// as it preserves the invariants of the neural network. + pub fn remove_connection(&mut self, conn: Connection) -> bool { + if self.remove_connection_raw(conn) { + self.remove_neuron(conn.to); return true; } - false } - /// Remove a neuron and downshift all connection indexes to compensate for it. - pub fn remove_neuron(&mut self, loc: impl AsRef<NeuronLocation>) { - let loc = loc.as_ref(); + /// Remove a neuron and downshift all connection indices to compensate for it. + /// Returns the number of neurons removed that were under the index of the removed neuron (including itself). + /// This will also deal with hanging neurons iteratively to avoid recursion that + /// can invalidate stored indices during nested deletions. + pub fn remove_neuron(&mut self, loc: NeuronLocation) -> usize { if !loc.is_hidden() { - panic!("Can only remove neurons from hidden layer"); + panic!("cannot remove neurons in input or output layer"); } - unsafe { - self.downshift_connections(loc.unwrap()); + let initial_i = loc.unwrap(); + + let mut work = VecDeque::new(); + work.push_back(loc); + + let mut removed = 0; + while let Some(cur_loc) = work.pop_front() { + // if the neuron was already removed due to earlier deletions, skip. + // i don't think it realistically should ever happen, but just in case. + if !self.neuron_exists(cur_loc) { + continue; + } + + let outputs = { + let n = &self[cur_loc]; + n.outputs.keys().cloned().collect::<Vec<_>>() + }; + + for target in outputs { + if self.remove_connection_raw(Connection { + from: cur_loc, + to: target, + }) { + // target became hanging; schedule it for removal. + work.push_back(target); + } + } + + // Re-check that the neuron still exists and is hidden before removing. + if !self.neuron_exists(cur_loc) || !cur_loc.is_hidden() { + continue; + } + + let i = cur_loc.unwrap(); + if i < self.hidden_layers.len() { + self.hidden_layers.remove(i); + if i <= initial_i { + removed += 1; + } + self.downshift_connections(i, &mut work); // O(n^2) bad, but we can optimize later if it's a problem. + } } + + removed } - unsafe fn downshift_connections(&mut self, i: usize) { + fn downshift_connections(&mut self, i: usize, work: &mut VecDeque<NeuronLocation>) { self.input_layer .par_iter_mut() .for_each(|n| n.downshift_outputs(i)); @@ -309,28 +505,96 @@ self.hidden_layers .par_iter_mut() .for_each(|n| n.downshift_outputs(i)); + + work.par_iter_mut().for_each(|loc| match loc { + NeuronLocation::Hidden(j) if *j > i => *j -= 1, + _ => {} + }); } - // TODO maybe more parallelism and pass Connection info. /// Runs the `callback` on the weights of the neural network in parallel, allowing it to modify weight values.
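+ /// For example, `net.update_weights(|_, w| *w *= 0.99);` applies a uniform weight decay.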
- pub fn map_weights(&mut self, callback: impl Fn(&mut f32) + Sync) { + pub fn update_weights(&mut self, callback: impl Fn(&NeuronLocation, &mut f32) + Sync) { for n in &mut self.input_layer { - n.outputs.par_iter_mut().for_each(|(_, w)| callback(w)); + n.outputs + .par_iter_mut() + .for_each(|(loc, w)| callback(loc, w)); } for n in &mut self.hidden_layers { - n.outputs.par_iter_mut().for_each(|(_, w)| callback(w)); + n.outputs + .par_iter_mut() + .for_each(|(loc, w)| callback(loc, w)); } } - unsafe fn clear_input_counts(&mut self) { - // not sure whether all this parallelism is necessary or if it will just generate overhead - // rayon::scope(|s| { - // s.spawn(|_| self.input_layer.par_iter_mut().for_each(|n| n.input_count = 0)); - // s.spawn(|_| self.hidden_layers.par_iter_mut().for_each(|n| n.input_count = 0)); - // s.spawn(|_| self.output_layer.par_iter_mut().for_each(|n| n.input_count = 0)); - // }); + /// Runs the `callback` on the neurons of the neural network in parallel, allowing it to modify neuron values. + pub fn mutate_neurons(&mut self, callback: impl Fn(&mut Neuron) + Sync) { + self.input_layer.par_iter_mut().for_each(&callback); + self.hidden_layers.par_iter_mut().for_each(&callback); + self.output_layer.par_iter_mut().for_each(&callback); + } + + /// Mutates the activation functions of the neurons in the neural network. + pub fn mutate_activations(&mut self, rate: f32) { + let reg = ACTIVATION_REGISTRY.read().unwrap(); + self.mutate_activations_with_reg(rate, &reg); + } + /// Mutates the activation functions of the neurons in the neural network, using a provided registry. + pub fn mutate_activations_with_reg(&mut self, rate: f32, reg: &ActivationRegistry) { + self.input_layer.par_iter_mut().for_each(|n| { + let mut rng = rand::rng(); + if rng.random_bool(rate as f64) { + n.mutate_activation(&reg.activations_in_scope(NeuronScope::INPUT), &mut rng); + } + }); + self.hidden_layers.par_iter_mut().for_each(|n| { + let mut rng = rand::rng(); + if rng.random_bool(rate as f64) { + n.mutate_activation(&reg.activations_in_scope(NeuronScope::HIDDEN), &mut rng); + } + }); + self.output_layer.par_iter_mut().for_each(|n| { + let mut rng = rand::rng(); + if rng.random_bool(rate as f64) { + n.mutate_activation(&reg.activations_in_scope(NeuronScope::OUTPUT), &mut rng); + } + }); + } + + /// Recounts inputs for all neurons in the network + /// and removes any invalid connections. + pub fn reset_input_counts(&mut self) { + self.clear_input_counts(); + + for i in 0..I { + self.reset_inputs_for_neuron(NeuronLocation::Input(i)); + } + + for i in 0..self.hidden_layers.len() { + self.reset_inputs_for_neuron(NeuronLocation::Hidden(i)); + } + } + + fn reset_inputs_for_neuron(&mut self, loc: NeuronLocation) { + let outputs = self[loc].outputs.keys().cloned().collect::<Vec<_>>(); + let outputs2 = outputs + .into_iter() + .filter(|&loc| { + if !self.neuron_exists(loc) { + return false; + } + + let target = &mut self[loc]; + target.input_count += 1; + true + }) + .collect::<HashSet<_>>(); + + self[loc].outputs.retain(|loc, _| outputs2.contains(loc)); + } + + fn clear_input_counts(&mut self) { self.input_layer .par_iter_mut() .for_each(|n| n.input_count = 0); @@ -342,146 +606,374 @@ .for_each(|n| n.input_count = 0); } - /// Recalculates the [`input_count`][`Neuron::input_count`] field for all neurons in the network. - pub fn recalculate_input_counts(&mut self) { - unsafe { self.clear_input_counts() }; + /// Iterates over the network and removes any hanging neurons in the hidden layer(s).
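+ /// A hanging neuron is a hidden neuron whose [`input_count`][Neuron::input_count] is 0, meaning no other neuron feeds into it.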
+    pub fn prune_hanging_neurons(&mut self) {
+        let mut i = 0;
+        while i < self.hidden_layers.len() {
+            let mut new_i = i + 1;
+            if self.hidden_layers[i].input_count == 0 {
+                // this saturating_sub is a code smell, but it works and avoids some edge cases where indices can get messed up.
+                new_i = new_i.saturating_sub(self.remove_neuron(NeuronLocation::Hidden(i)));
+            }
+            i = new_i;
+        }
+    }
+
+    /// Uses DFS to find and remove all cycles in O(n+e) time.
+    /// Expects [`prune_hanging_neurons`][NeuralNetwork::prune_hanging_neurons] to be called afterwards.
+    pub fn remove_cycles(&mut self) {
+        let mut visited = HashMap::new();
+        let mut edges_to_remove: HashSet<Connection> = HashSet::new();
 
         for i in 0..I {
-            for j in 0..self.input_layer[i].outputs.len() {
-                let (loc, _) = self.input_layer[i].outputs[j];
-                self.get_neuron_mut(loc).input_count += 1;
-            }
+            self.remove_cycles_dfs(
+                &mut visited,
+                &mut edges_to_remove,
+                None,
+                NeuronLocation::Input(i),
+            );
         }
 
+        // unattached cycles (will cause problems since they
+        // never get deleted by input_count == 0)
         for i in 0..self.hidden_layers.len() {
-            for j in 0..self.hidden_layers[i].outputs.len() {
-                let (loc, _) = self.hidden_layers[i].outputs[j];
-                self.get_neuron_mut(loc).input_count += 1;
+            let loc = NeuronLocation::Hidden(i);
+            if !visited.contains_key(&loc) {
+                self.remove_cycles_dfs(&mut visited, &mut edges_to_remove, None, loc);
             }
         }
+
+        for conn in edges_to_remove {
+            // only doing raw here since we recalculate input counts and
+            // prune hanging neurons later.
+            self.remove_connection_raw(conn);
+        }
     }
-}
 
-impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetwork<I, O> {
-    fn mutate(&mut self, rate: f32, rng: &mut impl Rng) {
-        if rng.gen::<f32>() <= rate {
-            // split connection
-            let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-            let n = self.get_neuron(from);
-            let (to, _) = n.random_output(rng);
+    // colored dfs
+    fn remove_cycles_dfs(
+        &mut self,
+        visited: &mut HashMap<NeuronLocation, u8>,
+        edges_to_remove: &mut HashSet<Connection>,
+        prev: Option<NeuronLocation>,
+        current: NeuronLocation,
+    ) {
+        if let Some(&existing) = visited.get(&current) {
+            if existing == 0 {
+                // part of current dfs - found a cycle
+                // prev must exist here since visited would be empty on first call.
+                let prev = prev.unwrap();
+                if self[prev].outputs.contains_key(&current) {
+                    edges_to_remove.insert(Connection {
+                        from: prev,
+                        to: current,
+                    });
+                }
+            }
 
-            self.split_connection(Connection { from, to }, rng);
+            // already fully visited, no need to check again
+            return;
         }
 
-        if rng.gen::<f32>() <= rate {
-            // add connection
-            let weight = rng.gen::<f32>();
+        visited.insert(current, 0);
+
+        let outputs = self[current].outputs.keys().cloned().collect::<Vec<_>>();
+        for loc in outputs {
+            self.remove_cycles_dfs(visited, edges_to_remove, Some(current), loc);
+        }
 
-            let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-            let to = self.random_location_in_scope(rng, !NeuronScope::INPUT);
+        visited.insert(current, 1);
+    }
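// Aside: the grey/black "colored DFS" above, as a standalone sketch on a plain
// adjacency list (illustrative names, not crate API). An edge into a node that
// is still on the current DFS path (grey, `0` here) is a back edge, i.e. part
// of a cycle; collecting and cutting those edges leaves the graph acyclic.
use std::collections::HashMap;

fn back_edges(adj: &[Vec<usize>]) -> Vec<(usize, usize)> {
    fn dfs(
        adj: &[Vec<usize>],
        node: usize,
        color: &mut HashMap<usize, u8>,
        out: &mut Vec<(usize, usize)>,
    ) {
        color.insert(node, 0); // grey: on the current path
        for &next in &adj[node] {
            match color.get(&next) {
                Some(&0) => out.push((node, next)), // grey target: back edge, cycle found
                Some(_) => {}                       // black: already fully explored
                None => dfs(adj, next, color, out),
            }
        }
        color.insert(node, 1); // black: finished
    }

    let mut color = HashMap::new();
    let mut out = Vec::new();
    // start from every unvisited node so detached components are covered,
    // mirroring the second loop in `remove_cycles`.
    for start in 0..adj.len() {
        if !color.contains_key(&start) {
            dfs(adj, start, &mut color, &mut out);
        }
    }
    out
}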
-            let mut connection = Connection { from, to };
-            while !self.add_connection(connection, weight) {
-                let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-                let to = self.random_location_in_scope(rng, !NeuronScope::INPUT);
-                connection = Connection { from, to };
-            }
-        }
 
+    /// Performs just the mutations that modify the graph structure of the neural network,
+    /// and not the internal mutations that only modify values such as activation functions, weights, and biases.
+    pub fn perform_graph_mutations(
+        &mut self,
+        settings: &MutationSettings,
+        rate: f32,
+        rng: &mut impl rand::Rng,
+    ) {
+        // TODO maybe allow specifying probability
+        // for each type of mutation
+        if settings
+            .allowed_mutations
+            .contains(GraphMutations::SPLIT_CONNECTION)
+            && rng.random_bool(rate as f64)
+        {
+            // split connection
+            if let Some(conn) = self.get_random_connection(settings.max_split_retries, rng) {
+                self.split_connection(conn, rng);
+            }
+        }
 
-        if rng.gen::<f32>() <= rate {
-            // remove connection
-
-            let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-            let a = self.get_neuron(from);
-            let (to, _) = a.random_output(rng);
+        if settings
+            .allowed_mutations
+            .contains(GraphMutations::ADD_CONNECTION)
+            && rng.random_bool(rate as f64)
+        {
+            // add connection
+            self.add_random_connection(settings.max_add_retries, rng);
+        }
 
-            self.remove_connection(Connection { from, to });
+        if settings
+            .allowed_mutations
+            .contains(GraphMutations::REMOVE_CONNECTION)
+            && rng.random_bool(rate as f64)
+        {
+            // remove connection
+            self.remove_random_connection(settings.max_remove_retries, rng);
         }
+    }
 
-        self.map_weights(|w| {
-            // TODO maybe `Send`able rng.
-            let mut rng = rand::thread_rng();
+    /// Performs just the mutations that modify internal values such as activation functions, weights, and biases,
+    /// and not the graph mutations that modify the structure of the neural network.
+    pub fn perform_internal_mutations(&mut self, settings: &MutationSettings, rate: f32) {
+        self.mutate_activations(rate);
+        self.mutate_weights(settings.weight_mutation_amount);
+    }
 
-            if rng.gen::<f32>() <= rate {
-                *w += rng.gen_range(-rate..rate);
-            }
+    /// Same as [`mutate`][NeuralNetwork::mutate] but allows specifying a custom activation registry for activation mutations.
+    pub fn mutate_with_reg(
+        &mut self,
+        settings: &MutationSettings,
+        rate: f32,
+        rng: &mut impl rand::Rng,
+        reg: &ActivationRegistry,
+    ) {
+        self.perform_graph_mutations(settings, rate, rng);
+        self.mutate_activations_with_reg(rate, reg);
+        self.mutate_weights(settings.weight_mutation_amount);
+    }
 
+    /// Mutates all weights by a random amount up to `max_amount` in either direction.
+    pub fn mutate_weights(&mut self, max_amount: f32) {
+        self.update_weights(|_, w| {
+            let mut rng = rand::rng();
+            let amount = rng.random_range(-max_amount..max_amount);
+            *w += amount;
         });
     }
 }
 
-impl<const I: usize, const O: usize> DivisionReproduction for NeuralNetwork<I, O> {
-    fn divide(&self, rng: &mut impl Rng) -> Self {
-        let mut child = self.clone();
+impl<const I: usize, const O: usize> Index<NeuronLocation> for NeuralNetwork<I, O> {
+    type Output = Neuron;
 
-        for _ in 0..self.mutation_settings.mutation_passes {
-            child.mutate(child.mutation_settings.mutation_rate, rng);
+    fn index(&self, loc: NeuronLocation) -> &Self::Output {
+        match loc {
+            NeuronLocation::Input(i) => &self.input_layer[i],
+            NeuronLocation::Hidden(i) => &self.hidden_layers[i],
+            NeuronLocation::Output(i) => &self.output_layer[i],
         }
+    }
+}
 
-        child
+impl<const I: usize, const O: usize> GenerateRandom for NeuralNetwork<I, O> {
+    fn gen_random(rng: &mut impl rand::Rng) -> Self {
+        Self::new(rng)
     }
 }
 
-#[allow(clippy::needless_range_loop)]
-impl<const I: usize, const O: usize> CrossoverReproduction for NeuralNetwork<I, O> {
-    fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self {
-        let mut output_layer = self.output_layer.clone();
+impl<const I: usize, const O: usize> IndexMut<NeuronLocation> for NeuralNetwork<I, O> {
+    fn index_mut(&mut self, loc: NeuronLocation) -> &mut Self::Output {
+        match loc {
+            NeuronLocation::Input(i) => &mut self.input_layer[i],
+            NeuronLocation::Hidden(i) => &mut self.hidden_layers[i],
+            NeuronLocation::Output(i) => &mut self.output_layer[i],
+        }
+    }
+}
 
-        for (i, n) in output_layer.iter_mut().enumerate() {
-            if rng.gen::<f32>() >= 0.5 {
-                *n = other.output_layer[i].clone();
-            }
+/// The mutation settings for [`NeuralNetwork`].
+/// Does not affect [`NeuralNetwork::mutate`], only [`NeuralNetwork::divide`] and [`NeuralNetwork::crossover`].
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone, PartialEq)]
+pub struct MutationSettings {
+    /// The chance of each mutation type to occur.
+    pub mutation_rate: f32,
+
+    /// The maximum amount that the weights will be mutated by in one mutation pass.
+    pub weight_mutation_amount: f32,
+
+    /// The maximum amount that biases will be mutated by in one mutation pass.
+    pub bias_mutation_amount: f32,
+
+    /// The maximum number of retries for adding connections.
+    pub max_add_retries: usize,
+
+    /// The maximum number of retries for removing connections.
+    pub max_remove_retries: usize,
+
+    /// The maximum number of retries for splitting connections.
+    pub max_split_retries: usize,
+
+    /// The types of graph mutations to allow during mutation.
+    /// Graph mutations are mutations that modify the structure of the neural network,
+    /// such as adding/removing connections and adding neurons.
+    pub allowed_mutations: GraphMutations,
+}
+
+impl Default for MutationSettings {
+    fn default() -> Self {
+        Self {
+            mutation_rate: 0.01,
+            weight_mutation_amount: 0.5,
+            bias_mutation_amount: 0.5,
+            max_add_retries: 10,
+            max_remove_retries: 10,
+            max_split_retries: 10,
+            allowed_mutations: GraphMutations::default(),
         }
+    }
+}
+
+bitflags! {
+    /// The types of graph mutations to allow during mutation.
+    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+    pub struct GraphMutations: u8 {
+        /// Mutation that splits an existing connection into two via a hidden neuron.
+        const SPLIT_CONNECTION = 0b00000001;
+        /// Mutation that adds a new connection between neurons.
+        const ADD_CONNECTION = 0b00000010;
+        /// Mutation that removes an existing connection.
+        const REMOVE_CONNECTION = 0b00000100;
+    }
+}
+
+impl Default for GraphMutations {
+    fn default() -> Self {
+        Self::all()
+    }
+}
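// Aside: a brief usage sketch of the settings types introduced above; struct
// update syntax over `Default` is one plausible way to configure them. Every
// name used here is defined in this patch.
fn growth_only_settings() -> MutationSettings {
    MutationSettings {
        // allow structural growth, but never remove connections.
        allowed_mutations: GraphMutations::SPLIT_CONNECTION | GraphMutations::ADD_CONNECTION,
        ..Default::default()
    }
}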
-        let hidden_len = self.hidden_layers.len().max(other.hidden_layers.len());
-        let mut hidden_layers = Vec::with_capacity(hidden_len);
 
+#[cfg(feature = "serde")]
+impl Serialize for GraphMutations {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.bits().serialize(serializer)
+    }
+}
 
+#[cfg(feature = "serde")]
+impl<'de> Deserialize<'de> for GraphMutations {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let bits = u8::deserialize(deserializer)?;
+        GraphMutations::from_bits(bits)
+            .ok_or_else(|| serde::de::Error::custom("invalid bit pattern for GraphMutations"))
+    }
+}
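// Aside: a round-trip sketch of the manual serde impls above. It assumes the
// `serde` feature is enabled and uses `serde_json` (already in the lockfile)
// purely for illustration; treat it as a starting point, not part of the commit.
#[cfg(all(test, feature = "serde"))]
#[test]
fn graph_mutations_serde_roundtrip() {
    // SPLIT_CONNECTION (0b001) | REMOVE_CONNECTION (0b100) serializes as "5".
    let flags = GraphMutations::SPLIT_CONNECTION | GraphMutations::REMOVE_CONNECTION;
    let json = serde_json::to_string(&flags).unwrap();
    let back: GraphMutations = serde_json::from_str(&json).unwrap();
    assert_eq!(flags, back);

    // bits outside the three defined flags are rejected, not silently truncated.
    assert!(serde_json::from_str::<GraphMutations>("255").is_err());
}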
-        for i in 0..hidden_len {
-            if rng.gen::<f32>() >= 0.5 {
-                if let Some(n) = self.hidden_layers.get(i) {
-                    let mut n = n.clone();
-                    n.prune_invalid_outputs(hidden_len, O);
-
-                    hidden_layers[i] = n;
-
-                    continue;
-                }
-            }
-
-            let mut n = other.hidden_layers[i].clone();
-            n.prune_invalid_outputs(hidden_len, O);
-
-            hidden_layers[i] = n;
-        }
+impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetwork<I, O> {
+    type Context = MutationSettings;
 
-        let mut input_layer = self.input_layer.clone();
+    fn mutate(&mut self, settings: &MutationSettings, rate: f32, rng: &mut impl Rng) {
+        let reg = ACTIVATION_REGISTRY.read().unwrap();
+        self.mutate_with_reg(settings, rate, rng, &reg);
+    }
+}
 
-        for (i, n) in input_layer.iter_mut().enumerate() {
-            if rng.gen::<f32>() >= 0.5 {
-                *n = other.input_layer[i].clone();
-            }
-            n.prune_invalid_outputs(hidden_len, O);
-        }
+/// The settings used for [`NeuralNetwork`] reproduction.
+#[derive(Debug, Clone, PartialEq)]
+pub struct ReproductionSettings {
+    /// The mutation settings to use during reproduction.
+    pub mutation: MutationSettings,
+
+    /// The number of times to apply mutation during reproduction.
+    pub mutation_passes: usize,
+}
+
+impl Default for ReproductionSettings {
+    fn default() -> Self {
+        Self {
+            mutation: MutationSettings::default(),
+            mutation_passes: 3,
+        }
+    }
+}
+
+impl<const I: usize, const O: usize> Mitosis for NeuralNetwork<I, O> {
+    type Context = ReproductionSettings;
+
+    fn divide(
+        &self,
+        settings: &ReproductionSettings,
+        rate: f32,
+        rng: &mut impl prelude::Rng,
+    ) -> Self {
+        let mut child = self.clone();
+
+        for _ in 0..settings.mutation_passes {
+            child.mutate(&settings.mutation, rate, rng);
+        }
+
+        child
+    }
+}
+
+impl<const I: usize, const O: usize> Crossover for NeuralNetwork<I, O> {
+    type Context = ReproductionSettings;
+
+    fn crossover(
+        &self,
+        other: &Self,
+        settings: &ReproductionSettings,
+        rate: f32,
+        rng: &mut impl rand::Rng,
+    ) -> Self {
+        // merge (temporarily breaking invariants) and then resolve invariants.
+        let mut child = NeuralNetwork {
+            input_layer: self.input_layer.clone(),
+            hidden_layers: vec![],
+            output_layer: self.output_layer.clone(),
+        };
+
+        for i in 0..I {
+            if rng.random_bool(0.5) {
+                child.input_layer[i] = other.input_layer[i].clone();
+            }
+        }
+
+        for i in 0..O {
+            if rng.random_bool(0.5) {
+                child.output_layer[i] = other.output_layer[i].clone();
+            }
+        }
 
-        // crossover mutation settings just in case.
-        let mutation_settings = if rng.gen::<f32>() >= 0.5 {
-            self.mutation_settings.clone()
+        let larger;
+        let smaller;
+        if self.hidden_layers.len() >= other.hidden_layers.len() {
+            larger = &self.hidden_layers;
+            smaller = &other.hidden_layers;
         } else {
-            other.mutation_settings.clone()
-        };
+            larger = &other.hidden_layers;
+            smaller = &self.hidden_layers;
+        }
 
-        let mut child = Self {
-            input_layer,
-            hidden_layers,
-            output_layer,
-            mutation_settings,
-        };
+        for i in 0..larger.len() {
+            if i < smaller.len() {
+                if rng.random_bool(0.5) {
+                    child.hidden_layers.push(smaller[i].clone());
+                } else {
+                    child.hidden_layers.push(larger[i].clone());
+                }
+                continue;
+            }
+
+            // larger is the only one with spare neurons, add them.
+            child.hidden_layers.push(larger[i].clone());
+        }
 
-        // TODO maybe find a way to do this while doing crossover stuff instead of recalculating everything.
-        // would be annoying to implement though.
-        child.recalculate_input_counts();
+        // resolve invariants
+        child.remove_cycles();
+        child.reset_input_counts();
+        child.prune_hanging_neurons();
 
-        for _ in 0..child.mutation_settings.mutation_passes {
-            child.mutate(child.mutation_settings.mutation_rate, rng);
+        for _ in 0..settings.mutation_passes {
+            child.mutate(&settings.mutation, rate, rng);
         }
 
         child
@@ -498,7 +990,7 @@ fn output_exists(loc: NeuronLocation, hidden_len: usize, output_len: usize) -> bool {
 
 /// A helper struct for operations on connections between neurons.
 /// It does not contain information about the weight.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub struct Connection {
     /// The source of the connection.
@@ -516,7 +1008,8 @@ pub struct Neuron {
     pub input_count: usize,
 
     /// The connections and weights to other neurons.
-    pub outputs: Vec<(NeuronLocation, f32)>,
+    #[cfg_attr(feature = "serde", serde(with = "outputs_serde"))]
+    pub outputs: HashMap<NeuronLocation, f32>,
 
     /// The initial value of the neuron.
     pub bias: f32,
@@ -528,14 +1021,14 @@ impl Neuron {
     /// Creates a new neuron with a specified activation function and outputs.
     pub fn new_with_activation(
-        outputs: Vec<(NeuronLocation, f32)>,
+        outputs: HashMap<NeuronLocation, f32>,
         activation_fn: ActivationFn,
         rng: &mut impl Rng,
     ) -> Self {
         Self {
             input_count: 0,
             outputs,
-            bias: rng.gen(),
+            bias: rng.random(),
             activation_fn,
         }
     }
@@ -543,26 +1036,23 @@ impl Neuron {
     /// Creates a new neuron with the given output locations.
     /// Chooses a random activation function within the specified scope.
     pub fn new(
-        outputs: Vec<(NeuronLocation, f32)>,
-        current_scope: NeuronScope,
+        outputs: HashMap<NeuronLocation, f32>,
+        scope: NeuronScope,
         rng: &mut impl Rng,
     ) -> Self {
         let reg = ACTIVATION_REGISTRY.read().unwrap();
-        let activations = reg.activations_in_scope(current_scope);
+        let act = reg.random_activation_in_scope(scope, rng);
 
-        Self::new_with_activations(outputs, activations, rng)
+        Self::new_with_activation(outputs, act, rng)
     }
 
     /// Creates a new neuron with the given outputs.
     /// Takes a collection of activation functions and chooses a random one from them to use.
     pub fn new_with_activations(
-        outputs: Vec<(NeuronLocation, f32)>,
-        activations: impl IntoIterator<Item = ActivationFn>,
+        outputs: HashMap<NeuronLocation, f32>,
+        activations: &[ActivationFn],
         rng: &mut impl Rng,
     ) -> Self {
-        // TODO get random in iterator form
-        let mut activations: Vec<_> = activations.into_iter().collect();
-        // TODO maybe Result instead.
         if activations.is_empty() {
             panic!("Empty activations list provided");
@@ -570,7 +1060,7 @@ impl Neuron {
 
         Self::new_with_activation(
             outputs,
-            activations.remove(rng.gen_range(0..activations.len())),
+            activations[rng.random_range(0..activations.len())].clone(),
             rng,
         )
     }
@@ -580,56 +1070,16 @@ impl Neuron {
         self.activation_fn.func.activate(v)
     }
 
-    /// Get the weight of the provided output location. Returns `None` if not found.
-    pub fn get_weight(&self, output: impl AsRef<NeuronLocation>) -> Option<f32> {
-        let loc = *output.as_ref();
-        for out in &self.outputs {
-            if out.0 == loc {
-                return Some(out.1);
-            }
-        }
-
-        None
-    }
-
-    /// Tries to remove a connection from the neuron and returns the weight if it was found.
-    ///
-    /// # Safety
-    /// This is marked as unsafe because it will not update the destination's [`input_count`][Neuron::input_count].
-    /// Similar to [`add_connection_raw`][NeuralNetwork::add_connection_raw], this does not mean UB or anything.
-    pub unsafe fn remove_connection(&mut self, output: impl AsRef<NeuronLocation>) -> Option<f32> {
-        let loc = *output.as_ref();
-        let mut i = 0;
-
-        while i < self.outputs.len() {
-            if self.outputs[i].0 == loc {
-                return Some(self.outputs.remove(i).1);
-            }
-            i += 1;
-        }
-
-        None
-    }
-
     /// Randomly mutates the specified weight with the rate.
     pub fn mutate_weight(
         &mut self,
-        output: impl AsRef<NeuronLocation>,
+        output: NeuronLocation,
         rate: f32,
         rng: &mut impl Rng,
     ) -> Option<f32> {
-        let loc = *output.as_ref();
-        let mut i = 0;
-
-        while i < self.outputs.len() {
-            let o = &mut self.outputs[i];
-            if o.0 == loc {
-                o.1 += rng.gen_range(-rate..rate);
-
-                return Some(o.1);
-            }
-
-            i += 1;
+        if let Some(w) = self.outputs.get_mut(&output) {
+            *w += rng.random_range(-rate..=rate);
+            return Some(*w);
         }
 
         None
@@ -637,11 +1087,13 @@ impl Neuron {
 
     /// Get a random output location and weight.
     pub fn random_output(&self, rng: &mut impl Rng) -> (NeuronLocation, f32) {
-        self.outputs[rng.gen_range(0..self.outputs.len())]
+        // will panic if outputs is empty
+        let i = rng.random_range(0..self.outputs.len());
+        let x = self.outputs.iter().nth(i).unwrap();
+        (*x.0, *x.1)
     }
 
     pub(crate) fn downshift_outputs(&mut self, i: usize) {
-        // TODO par_iter_mut instead of replace
         replace_with_or_abort(&mut self.outputs, |o| {
             o.into_par_iter()
                 .map(|(loc, w)| match loc {
@@ -655,12 +1107,21 @@ impl Neuron {
 
     /// Removes any outputs pointing to a nonexistent neuron.
     pub fn prune_invalid_outputs(&mut self, hidden_len: usize, output_len: usize) {
         self.outputs
-            .retain(|(loc, _)| output_exists(*loc, hidden_len, output_len));
+            .retain(|loc, _| output_exists(*loc, hidden_len, output_len));
+    }
+
+    /// Replaces the activation function with a random one.
+    pub fn mutate_activation(&mut self, activations: &[ActivationFn], rng: &mut impl Rng) {
+        if activations.is_empty() {
+            panic!("Empty activations list provided");
+        }
+
+        self.activation_fn = activations[rng.random_range(0..activations.len())].clone();
     }
 }
 
 /// A pseudo-pointer of sorts that is used for caching.
-#[derive(Hash, Clone, Copy, Debug, Eq, PartialEq)]
+#[derive(Hash, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub enum NeuronLocation {
     /// Points to a neuron in the input layer at contained index.
@@ -697,6 +1158,16 @@ impl NeuronLocation {
             Self::Output(i) => *i,
         }
     }
+
+    /// Sets the inner index value without changing the layer.
+    pub fn set_inner(&mut self, v: usize) {
+        // there's gotta be a cleaner way of doing this
+        match self {
+            Self::Input(i) => *i = v,
+            Self::Hidden(i) => *i = v,
+            Self::Output(i) => *i = v,
+        }
+    }
 }
 
 impl AsRef<NeuronLocation> for NeuronLocation {
@@ -759,11 +1230,11 @@ pub struct NeuralNetCache<const I: usize, const O: usize> {
 
 impl<const I: usize, const O: usize> NeuralNetCache<I, O> {
     /// Gets the value of a neuron at the given location.
-    pub fn get(&self, loc: impl AsRef<NeuronLocation>) -> f32 {
-        match loc.as_ref() {
-            NeuronLocation::Input(i) => self.input_layer[*i].value.load(Ordering::SeqCst),
-            NeuronLocation::Hidden(i) => self.hidden_layers[*i].value.load(Ordering::SeqCst),
-            NeuronLocation::Output(i) => self.output_layer[*i].value.load(Ordering::SeqCst),
+    pub fn get(&self, loc: NeuronLocation) -> f32 {
+        match loc {
+            NeuronLocation::Input(i) => self.input_layer[i].value.load(Ordering::SeqCst),
+            NeuronLocation::Hidden(i) => self.hidden_layers[i].value.load(Ordering::SeqCst),
+            NeuronLocation::Output(i) => self.output_layer[i].value.load(Ordering::SeqCst),
         }
     }
@@ -811,17 +1282,6 @@ impl<const I: usize, const O: usize> NeuralNetCache<I, O> {
         }
     }
 
-    /// Fetches and packs the output layer values into an array.
-    pub fn output(&self) -> [f32; O] {
-        let output: Vec<_> = self
-            .output_layer
-            .par_iter()
-            .map(|c| c.value.load(Ordering::SeqCst))
-            .collect();
-
-        output.try_into().unwrap()
-    }
-
     /// Attempts to claim a neuron. Returns false if it has already been claimed.
     pub fn claim(&self, loc: impl AsRef<NeuronLocation>) -> bool {
         match loc.as_ref() {
diff --git a/src/tests.rs b/src/tests.rs
index 825cdee..b1345a2 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -1,179 +1,339 @@
-use crate::*;
-use rand::prelude::*;
+use std::collections::HashMap;
 
-// no support for tuple structs derive in genetic-rs yet :(
-#[derive(Debug, Clone, PartialEq)]
-struct Agent(NeuralNetwork<4, 1>);
+use crate::{activation::builtin::linear_activation, *};
+use genetic_rs::prelude::rand::{rngs::StdRng, SeedableRng};
+use rayon::prelude::*;
 
-impl Prunable for Agent {}
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+enum GraphCheckState {
+    CurrentCycle,
+    Checked,
+}
+
+fn assert_graph_invariants<const I: usize, const O: usize>(net: &NeuralNetwork<I, O>) {
+    let mut visited = HashMap::new();
 
-impl RandomlyMutable for Agent {
-    fn mutate(&mut self, rate: f32, rng: &mut impl Rng) {
-        self.0.mutate(rate, rng);
+    for i in 0..I {
+        dfs(net, NeuronLocation::Input(i), &mut visited);
+    }
+
+    for i in 0..net.hidden_layers.len() {
+        let loc = NeuronLocation::Hidden(i);
+        if !visited.contains_key(&loc) {
+            panic!("hanging neuron: {loc:?}");
+        }
     }
 }
 
-impl DivisionReproduction for Agent {
-    fn divide(&self, rng: &mut impl rand::Rng) -> Self {
-        Self(self.0.divide(rng))
+// simple colored dfs for checking graph invariants.
+fn dfs<const I: usize, const O: usize>(
+    net: &NeuralNetwork<I, O>,
+    loc: NeuronLocation,
+    visited: &mut HashMap<NeuronLocation, GraphCheckState>,
+) {
+    if let Some(existing) = visited.get(&loc) {
+        match *existing {
+            GraphCheckState::CurrentCycle => panic!("cycle detected on {loc:?}"),
+            GraphCheckState::Checked => return,
+        }
+    }
+
+    visited.insert(loc, GraphCheckState::CurrentCycle);
+
+    for loc2 in net[loc].outputs.keys() {
+        dfs(net, *loc2, visited);
     }
+
+    visited.insert(loc, GraphCheckState::Checked);
 }
 
-impl CrossoverReproduction for Agent {
-    fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self {
-        Self(self.0.crossover(&other.0, rng))
+struct InputCountsCache<const O: usize> {
+    hidden_layers: Vec<usize>,
+    output: [usize; O],
+}
+
+impl<const O: usize> InputCountsCache<O> {
+    fn tally(&mut self, loc: NeuronLocation) {
+        match loc {
+            NeuronLocation::Input(_) => panic!("input neurons can't have inputs"),
+            NeuronLocation::Hidden(i) => self.hidden_layers[i] += 1,
+            NeuronLocation::Output(i) => self.output[i] += 1,
+        }
     }
 }
 
-struct GuessTheNumber(f32);
+// asserts that cached/tracked values are correct. mainly only used for
+// input count and such
+fn assert_cache_consistency<const I: usize, const O: usize>(net: &NeuralNetwork<I, O>) {
+    let mut cache = InputCountsCache {
+        hidden_layers: vec![0; net.hidden_layers.len()],
+        output: [0; O],
+    };
 
-impl GuessTheNumber {
-    fn new(rng: &mut impl Rng) -> Self {
-        Self(rng.gen())
+    for i in 0..I {
+        let n = &net[NeuronLocation::Input(i)];
+        for loc in n.outputs.keys() {
+            cache.tally(*loc);
+        }
     }
 
-    fn guess(&self, n: f32) -> Option<f32> {
-        if n > self.0 + 1.0e-5 {
-            return Some(1.);
+    for n in &net.hidden_layers {
+        for loc in n.outputs.keys() {
+            cache.tally(*loc);
         }
+    }
 
-        if n < self.0 - 1.0e-5 {
-            return Some(-1.);
+    for (i, x) in cache.hidden_layers.into_iter().enumerate() {
+        if x == 0 {
+            // redundant because of graph invariants, but better safe than sorry
+            panic!("found hanging neuron");
         }
-        // guess was correct (or at least within margin of error).
-        None
+        assert_eq!(x, net.hidden_layers[i].input_count);
     }
-}
 
-fn fitness(agent: &Agent) -> f32 {
-    let mut rng = rand::thread_rng();
+    for (i, x) in cache.output.into_iter().enumerate() {
+        assert_eq!(x, net.output_layer[i].input_count);
+    }
+}
 
-    let mut fitness = 0.;
+fn assert_network_invariants<const I: usize, const O: usize>(net: &NeuralNetwork<I, O>) {
+    assert_graph_invariants(net);
+    assert_cache_consistency(net);
+    // TODO other invariants
+}
 
-    // 10 games for consistency
-    for _ in 0..10 {
-        let game = GuessTheNumber::new(&mut rng);
+const TEST_COUNT: u64 = 1000;
+fn rng_test(test: impl Fn(&mut StdRng) + Sync) {
+    (0..TEST_COUNT).into_par_iter().for_each(|seed| {
+        let mut rng = StdRng::seed_from_u64(seed);
+        test(&mut rng);
+    });
+}
 
-        let mut last_guess = 0.;
-        let mut last_result = 0.;
+#[test]
+fn create_network() {
+    rng_test(|rng| {
+        let net = NeuralNetwork::<10, 10>::new(rng);
+        assert_network_invariants(&net);
+    });
+}
 
-        let mut last_guess_2 = 0.;
-        let mut last_result_2 = 0.;
+#[test]
+fn split_connection() {
+    // rng doesn't matter here since it's just adding bias in eval
+    let mut rng = StdRng::seed_from_u64(0xabcdef);
+
+    let mut net = NeuralNetwork::<1, 1>::new(&mut rng);
+    assert_network_invariants(&net);
+
+    net.split_connection(
+        Connection {
+            from: NeuronLocation::Input(0),
+            to: NeuronLocation::Output(0),
+        },
+        &mut rng,
+    );
+    assert_network_invariants(&net);
+
+    assert_eq!(
+        *net.input_layer[0].outputs.keys().next().unwrap(),
+        NeuronLocation::Hidden(0)
+    );
+    assert_eq!(
+        *net.hidden_layers[0].outputs.keys().next().unwrap(),
+        NeuronLocation::Output(0)
+    );
+}
 
-        let mut steps = 0;
-        loop {
-            if steps >= 20 {
-                // took too many guesses
-                fitness -= 50.;
-                break;
-            }
+#[test]
+fn add_connection() {
+    let mut rng = StdRng::seed_from_u64(0xabcdef);
+    let mut net = NeuralNetwork {
+        input_layer: [Neuron::new_with_activation(
+            HashMap::new(),
+            activation_fn!(linear_activation),
+            &mut rng,
+        )],
+        hidden_layers: vec![],
+        output_layer: [Neuron::new_with_activation(
+            HashMap::new(),
+            activation_fn!(linear_activation),
+            &mut rng,
+        )],
+    };
+    assert_network_invariants(&net);
 
-            let [cur_guess] =
-                agent
-                    .0
-                    .predict([last_guess, last_result, last_guess_2, last_result_2]);
+    let mut conn = Connection {
+        from: NeuronLocation::Input(0),
+        to: NeuronLocation::Output(0),
+    };
+    assert!(net.add_connection(conn, 0.1));
+    assert_network_invariants(&net);
 
-            let cur_result = game.guess(cur_guess);
+    assert!(!net.add_connection(conn, 0.1));
+    assert_network_invariants(&net);
 
-            if let Some(result) = cur_result {
-                last_guess = last_guess_2;
-                last_result = last_result_2;
+    let mut outputs = HashMap::new();
+    outputs.insert(NeuronLocation::Output(0), 0.1);
+    let n = Neuron::new_with_activation(outputs, activation_fn!(linear_activation), &mut rng);
 
-                last_guess_2 = cur_guess;
-                last_result_2 = result;
+    net.add_neuron(n.clone());
+    // temporarily broken invariants bc of hanging neuron
 
-                fitness -= 1.;
-                steps += 1;
+    conn.to = NeuronLocation::Hidden(0);
+    assert!(net.add_connection(conn, 0.1));
+    assert_network_invariants(&net);
 
-                continue;
-            }
+    net.add_neuron(n);
 
-            fitness += 50.;
-            break;
-        }
-    }
+    conn.to = NeuronLocation::Hidden(1);
+    assert!(net.add_connection(conn, 0.1));
+    assert_network_invariants(&net);
 
-    fitness
-}
+    conn.from = NeuronLocation::Hidden(0);
+    assert!(net.add_connection(conn, 0.1));
+    assert_network_invariants(&net);
 
-#[test]
-fn division() {
-    let mut rng = rand::thread_rng();
+    net.split_connection(conn, &mut rng);
+    assert_network_invariants(&net);
 
-    let starting_genomes =
-        (0..100)
-            .map(|_| Agent(NeuralNetwork::new(MutationSettings::default(), &mut rng)))
-            .collect();
+    conn.from = NeuronLocation::Hidden(2);
+    conn.to = NeuronLocation::Hidden(0);
 
-    let mut sim = GeneticSim::new(starting_genomes, fitness, division_pruning_nextgen);
+    assert!(!net.add_connection(conn, 0.1));
+    assert_network_invariants(&net);
 
-    sim.perform_generations(100);
+    // random stress testing
+    rng_test(|rng| {
+        let mut net = NeuralNetwork::<10, 10>::new(rng);
+        assert_network_invariants(&net);
+        for _ in 0..50 {
+            net.add_random_connection(10, rng);
+            assert_network_invariants(&net);
+        }
+    });
 }
 
 #[test]
-fn crossover() {
-    let mut rng = rand::thread_rng();
-
-    let starting_genomes = (0..100)
-        .map(|_| Agent(NeuralNetwork::new(MutationSettings::default(), &mut rng)))
-        .collect();
-
-    let mut sim = GeneticSim::new(starting_genomes, fitness, crossover_pruning_nextgen);
-
-    sim.perform_generations(100);
+fn remove_connection() {
+    let mut rng = StdRng::seed_from_u64(0xabcdef);
+    let mut net = NeuralNetwork {
+        input_layer: [Neuron::new_with_activation(
+            HashMap::from([
+                (NeuronLocation::Output(0), 0.1),
+                (NeuronLocation::Hidden(0), 1.0),
+            ]),
+            activation_fn!(linear_activation),
+            &mut rng,
+        )],
+        hidden_layers: vec![Neuron {
+            input_count: 1,
+            outputs: HashMap::new(), // not sure whether i want neurons with no outputs to break the invariant/be removed
+            bias: 0.0,
+            activation_fn: activation_fn!(linear_activation),
+        }],
+        output_layer: [Neuron {
+            input_count: 1,
+            outputs: HashMap::new(),
+            bias: 0.0,
+            activation_fn: activation_fn!(linear_activation),
+        }],
+    };
+    assert_network_invariants(&net);
+
+    assert!(!net.remove_connection(Connection {
+        from: NeuronLocation::Input(0),
+        to: NeuronLocation::Output(0)
+    }));
+    assert_network_invariants(&net);
+
+    assert!(net.remove_connection(Connection {
+        from: NeuronLocation::Input(0),
+        to: NeuronLocation::Hidden(0)
+    }));
+    assert_network_invariants(&net);
+
+    rng_test(|rng| {
+        let mut net = NeuralNetwork::<10, 10>::new(rng);
+        assert_network_invariants(&net);
+
+        for _ in 0..70 {
+            net.add_random_connection(10, rng);
+            assert_network_invariants(&net);
+
+            if rng.random_bool(0.25) {
+                // rng allows network to form more complex edge cases.
+                net.remove_random_connection(5, rng);
+                // don't need to remove neuron since this
+                // method handles it automatically.
+                assert_network_invariants(&net);
+            }
+        }
+    });
 }
 
-#[cfg(feature = "serde")]
-#[test]
-fn serde() {
-    let mut rng = rand::thread_rng();
-    let net: NeuralNetwork<5, 10> = NeuralNetwork::new(MutationSettings::default(), &mut rng);
+// TODO remove_neuron test
 
-    let text = serde_json::to_string(&net).unwrap();
+const NUM_MUTATIONS: usize = 50;
+const MUTATION_RATE: f32 = 0.25;
+#[test]
+fn mutate() {
+    rng_test(|rng| {
+        let mut net = NeuralNetwork::<10, 10>::new(rng);
+        assert_network_invariants(&net);
 
-    let net2: NeuralNetwork<5, 10> = serde_json::from_str(&text).unwrap();
+        let settings = MutationSettings::default();
 
-    assert_eq!(net, net2);
+        for _ in 0..NUM_MUTATIONS {
+            net.mutate(&settings, MUTATION_RATE, rng);
+            assert_network_invariants(&net);
+        }
+    });
 }
 
 #[test]
-fn neural_net_cache_sync() {
-    let cache = NeuralNetCache {
-        input_layer: [NeuronCache::new(0.3, 0), NeuronCache::new(0.25, 0)],
-        hidden_layers: vec![
-            NeuronCache::new(0.2, 2),
-            NeuronCache::new(0.0, 2),
-            NeuronCache::new(1.5, 2),
-        ],
-        output_layer: [NeuronCache::new(0.0, 3), NeuronCache::new(0.0, 3)],
-    };
+fn crossover() {
+    rng_test(|rng| {
+        let mut net1 = NeuralNetwork::<10, 10>::new(rng);
+        assert_network_invariants(&net1);
 
-    for i in 0..2 {
-        let input_loc = NeuronLocation::Input(i);
+        let mut net2 = NeuralNetwork::<10, 10>::new(rng);
+        assert_network_invariants(&net2);
 
-        assert!(cache.claim(&input_loc));
+        let settings = ReproductionSettings::default();
 
-        for j in 0..3 {
-            cache.add(
-                NeuronLocation::Hidden(j),
-                f32::tanh(cache.get(&input_loc) * 1.2),
-            );
-        }
-    }
-
-    for i in 0..3 {
-        let hidden_loc = NeuronLocation::Hidden(i);
+        for _ in 0..NUM_MUTATIONS {
+            let a = net1.crossover(&net2, &settings, MUTATION_RATE, rng);
+            assert_network_invariants(&a);
 
-        assert!(cache.is_ready(&hidden_loc));
-        assert!(cache.claim(&hidden_loc));
+            let b = net2.crossover(&net1, &settings, MUTATION_RATE, rng);
+            assert_network_invariants(&b);
 
-        for j in 0..2 {
-            cache.add(
-                NeuronLocation::Output(j),
-                activation::builtin::sigmoid(cache.get(&hidden_loc) * 0.7),
-            );
+            net1 = a;
+            net2 = b;
         }
-    }
+    });
+}
 
-    assert_eq!(cache.output(), [2.0688455, 2.0688455]);
+#[cfg(feature = "serde")]
+mod serde {
+    use super::rng_test;
+    use crate::*;
+
+    #[test]
+    fn full_serde() {
+        rng_test(|rng| {
+            let net1 = NeuralNetwork::<10, 10>::new(rng);
+
+            let mut buf = Vec::new();
+            let writer = std::io::Cursor::new(&mut buf);
+            let mut serializer = serde_json::Serializer::new(writer);
+
+            serde_path_to_error::serialize(&net1, &mut serializer).unwrap();
+            let serialized = serde_json::to_string(&net1).unwrap();
+            let net2: NeuralNetwork<10, 10> = serde_json::from_str(&serialized).unwrap();
+            assert_eq!(net1, net2);
+        });
+    }
 }
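// Aside: a sketch of the `remove_neuron` test left as a TODO above; a hedged
// starting point, not part of the commit. It only uses items defined in this
// patch, and assumes `remove_neuron` is meant for hanging neurons, so each
// hidden neuron is detached via `remove_connection` first (the last
// detachment cascades into `remove_neuron` internally).
#[test]
fn remove_neuron() {
    rng_test(|rng| {
        let mut net = NeuralNetwork::<10, 10>::new(rng);

        // grow some hidden structure so there is something to remove.
        for _ in 0..10 {
            if let Some(conn) = net.get_random_connection(10, rng) {
                net.split_connection(conn, rng);
            }
        }
        assert_network_invariants(&net);

        while !net.hidden_layers.is_empty() {
            let target = NeuronLocation::Hidden(0);

            // collect every connection feeding the target before mutating;
            // no cascade fires until the target's last input is removed, so
            // these locations stay valid for the whole loop below.
            let mut sources = Vec::new();
            for i in 0..10 {
                if net.input_layer[i].outputs.contains_key(&target) {
                    sources.push(NeuronLocation::Input(i));
                }
            }
            for (i, n) in net.hidden_layers.iter().enumerate() {
                if n.outputs.contains_key(&target) {
                    sources.push(NeuronLocation::Hidden(i));
                }
            }

            for from in sources {
                net.remove_connection(Connection { from, to: target });
            }
            assert_network_invariants(&net);
        }
    });
}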