diff --git a/NAM/activations.h b/NAM/activations.h
index 3e77614..4429964 100644
--- a/NAM/activations.h
+++ b/NAM/activations.h
@@ -111,6 +111,15 @@ class Activation
   static std::unordered_map<std::string, Activation*> _activations;
 };
 
+// identity function activation
+class ActivationIdentity : public nam::activations::Activation
+{
+public:
+  ActivationIdentity() = default;
+  ~ActivationIdentity() = default;
+  // Inherit the default apply methods which do nothing
+};
+
 class ActivationTanh : public Activation
 {
 public:
diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h
new file mode 100644
index 0000000..3436cdd
--- /dev/null
+++ b/NAM/gating_activations.h
@@ -0,0 +1,170 @@
+#pragma once
+
+#include <cassert>
+#include <cmath> // expf
+#include <stdexcept>
+#include <unordered_map>
+#include <Eigen/Dense>
+#include "activations.h"
+
+namespace nam
+{
+namespace gating_activations
+{
+
+// Default linear activation (identity function)
+class IdentityActivation : public nam::activations::Activation
+{
+public:
+  IdentityActivation() = default;
+  ~IdentityActivation() = default;
+  // Inherit the default apply methods which do nothing (linear/identity)
+};
+
+class GatingActivation
+{
+public:
+  /**
+   * Constructor for GatingActivation
+   * @param input_act Activation function for input channels
+   * @param gating_act Activation function for gating channels
+   * @param input_channels Number of input channels (default: 1). The number of gating channels is assumed equal.
+   */
+  GatingActivation(activations::Activation* input_act, activations::Activation* gating_act, int input_channels = 1)
+  : input_activation(input_act)
+  , gating_activation(gating_act)
+  , num_channels(input_channels)
+  {
+    assert(num_channels > 0);
+  }
+
+  ~GatingActivation() = default;
+
+  /**
+   * Apply gating activation to input matrix
+   * @param input Input matrix with shape (input_channels + gating_channels) x num_samples
+   * @param output Output matrix with shape input_channels x num_samples
+   */
+  void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
+  {
+    // Validate input dimensions (assert for real-time performance)
+    const int total_channels = 2 * num_channels;
+    assert(input.rows() == total_channels);
+    assert(output.rows() == num_channels);
+    assert(output.cols() == input.cols());
+
+    // Process column-by-column to ensure memory contiguity (important for column-major matrices)
+    const int num_samples = input.cols();
+    for (int i = 0; i < num_samples; i++)
+    {
+      // Apply activation to input channels
+      Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
+      input_activation->apply(input_block);
+
+      // Apply activation to gating channels
+      Eigen::MatrixXf gating_block = input.block(num_channels, i, num_channels, 1);
+      gating_activation->apply(gating_block);
+
+      // Element-wise multiplication and store result
+      // For wavenet compatibility, we assume one-to-one mapping
+      output.block(0, i, num_channels, 1) = input_block.array() * gating_block.array();
+    }
+  }
+
+  /**
+   * Get the total number of input channels required
+   */
+  int get_input_channels() const { return 2 * num_channels; }
+
+  /**
+   * Get the number of output channels
+   */
+  int get_output_channels() const { return num_channels; }
+
+private:
+  activations::Activation* input_activation;
+  activations::Activation* gating_activation;
+  int num_channels;
+};
+
+class BlendingActivation
+{
+public:
+  /**
+   * Constructor for BlendingActivation
+   * @param input_act Activation function for input channels
+   * @param blend_act Activation function for blending channels
+   * @param input_channels Number of input channels
+   */
+  BlendingActivation(activations::Activation* input_act, activations::Activation* blend_act, int input_channels = 1)
+  : input_activation(input_act)
+  , blending_activation(blend_act)
+  , num_channels(input_channels)
+  {
+    if (num_channels <= 0)
+    {
+      throw std::invalid_argument("BlendingActivation: number of input channels must be positive");
+    }
+    // Initialize input buffer with correct size
+    // Note: current code copies column-by-column so we only need (num_channels, 1)
+    input_buffer.resize(num_channels, 1);
+  }
+
+  ~BlendingActivation() = default;
+
+  /**
+   * Apply blending activation to input matrix
+   * @param input Input matrix with shape (input_channels + blend_channels) x num_samples
+   * @param output Output matrix with shape input_channels x num_samples
+   */
+  void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
+  {
+    // Validate input dimensions (assert for real-time performance)
+    const int total_channels = num_channels * 2; // 2*channels in, channels out
+    assert(input.rows() == total_channels);
+    assert(output.rows() == num_channels);
+    assert(output.cols() == input.cols());
+
+    // Process column-by-column to ensure memory contiguity
+    const int num_samples = input.cols();
+    for (int i = 0; i < num_samples; i++)
+    {
+      // Store pre-activation input values in buffer
+      input_buffer = input.block(0, i, num_channels, 1);
+
+      // Apply activation to input channels
+      Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
+      input_activation->apply(input_block);
+
+      // Apply activation to blend channels to compute alpha
+      Eigen::MatrixXf blend_block = input.block(num_channels, i, num_channels, 1);
+      blending_activation->apply(blend_block);
+
+      // Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
+      output.block(0, i, num_channels, 1) =
+        blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array();
+    }
+  }
+
+  /**
+   * Get the total number of input channels required
+   */
+  int get_input_channels() const { return 2 * num_channels; }
+
+  /**
+   * Get the number of output channels
+   */
+  int get_output_channels() const { return num_channels; }
+
+private:
+  activations::Activation* input_activation;
+  activations::Activation* blending_activation;
+  int num_channels;
+  Eigen::MatrixXf input_buffer;
+};
+
+
+}; // namespace gating_activations
+}; // namespace nam
diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp
index 928bd9f..2aa66ec 100644
--- a/tools/run_tests.cpp
+++ b/tools/run_tests.cpp
@@ -7,6 +7,10 @@
 #include "test/test_get_dsp.cpp"
 #include "test/test_wavenet.cpp"
 #include "test/test_fast_lut.cpp"
+#include "test/test_gating_activations.cpp"
+#include "test/test_wavenet_gating_compatibility.cpp"
+#include "test/test_blending_detailed.cpp"
+#include "test/test_input_buffer_verification.cpp"
 
 int main()
 {
@@ -24,7 +28,7 @@ int main()
   test_activations::TestPReLU::test_core_function();
   test_activations::TestPReLU::test_per_channel_behavior();
   // This is enforced by an assert so it doesn't need to be tested
-  //test_activations::TestPReLU::test_wrong_number_of_channels();
+  // test_activations::TestPReLU::test_wrong_number_of_channels();
 
   test_dsp::test_construct();
   test_dsp::test_get_input_level();
@@ -44,6 +48,31 @@ int main()
   test_lut::TestFastLUT::test_sigmoid();
   test_lut::TestFastLUT::test_tanh();
 
+  // Gating activations tests
+  test_gating_activations::TestGatingActivation::test_basic_functionality();
+  test_gating_activations::TestGatingActivation::test_with_custom_activations();
+  // test_gating_activations::TestGatingActivation::test_error_handling();
+
+  // Wavenet gating compatibility tests
+  test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_wavenet_style_gating();
+  test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_column_by_column_processing();
+  test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_memory_contiguity();
+  test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_multiple_channels();
+
+  test_gating_activations::TestBlendingActivation::test_basic_functionality();
+  test_gating_activations::TestBlendingActivation::test_blending_behavior();
+  test_gating_activations::TestBlendingActivation::test_with_custom_activations();
+  // test_gating_activations::TestBlendingActivation::test_error_handling();
+  test_gating_activations::TestBlendingActivation::test_edge_cases();
+
+  // Detailed blending tests
+  test_blending_detailed::TestBlendingDetailed::test_blending_with_different_activations();
+  test_blending_detailed::TestBlendingDetailed::test_input_buffer_usage();
+
+  // Input buffer verification tests
+  test_input_buffer_verification::TestInputBufferVerification::test_buffer_stores_pre_activation_values();
+  test_input_buffer_verification::TestInputBufferVerification::test_buffer_with_different_activations();
+
   std::cout << "Success!" << std::endl;
   return 0;
 }
diff --git a/tools/test/test_blending_detailed.cpp b/tools/test/test_blending_detailed.cpp
new file mode 100644
index 0000000..b526ae9
--- /dev/null
+++ b/tools/test/test_blending_detailed.cpp
@@ -0,0 +1,115 @@
+// Detailed test for BlendingActivation behavior
+
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <vector>
+#include <Eigen/Dense>
+
+#include "NAM/gating_activations.h"
+#include "NAM/activations.h"
+
+namespace test_blending_detailed
+{
+
+class TestBlendingDetailed
+{
+public:
+  static void test_blending_with_different_activations()
+  {
+    // Test case: 2 input channels, so we need 4 total input channels (2*channels in)
+    Eigen::MatrixXf input(4, 2); // 4 rows (2 input + 2 blending), 2 samples
+    input << 1.0f, 2.0f, // Input channel 1
+      3.0f, 4.0f, // Input channel 2
+      0.5f, 0.8f, // Blending channel 1
+      0.3f, 0.6f; // Blending channel 2
+
+    Eigen::MatrixXf output(2, 2); // 2 output channels, 2 samples
+
+    // Test with default (linear) activations
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationIdentity identity_blend_act;
+    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 2);
+    blending_act.apply(input, output);
+
+    std::cout << "Blending with linear activations:" << std::endl;
+    std::cout << "Input:" << std::endl << input << std::endl;
+    std::cout << "Output:" << std::endl << output << std::endl;
+
+    // With linear activations:
+    // alpha = blend_input (since linear activation does nothing)
+    // output = alpha * input + (1 - alpha) * input = input
+    // So output should equal the input channels after activation (which is the same as input)
+    assert(fabs(output(0, 0) - 1.0f) < 1e-6);
+    assert(fabs(output(1, 0) - 3.0f) < 1e-6);
+    assert(fabs(output(0, 1) - 2.0f) < 1e-6);
+    assert(fabs(output(1, 1) - 4.0f) < 1e-6);
+
+    // Test with sigmoid blending activation
+    nam::activations::Activation* sigmoid_act = nam::activations::Activation::get_activation("Sigmoid");
+    nam::gating_activations::BlendingActivation blending_act_sigmoid(&identity_act, sigmoid_act, 2);
+
+    Eigen::MatrixXf output_sigmoid(2, 2);
+    blending_act_sigmoid.apply(input, output_sigmoid);
+
+    std::cout << "Blending with sigmoid blending activation:" << std::endl;
+    std::cout << "Output:" << std::endl << output_sigmoid << std::endl;
+
+    // With sigmoid blending, alpha values should be between 0 and 1
+    // For blend input 0.5, sigmoid(0.5) ≈ 0.622
+    // For blend input 0.8, sigmoid(0.8) ≈ 0.690
+    // For blend input 0.3, sigmoid(0.3) ≈ 0.574
+    // For blend input 0.6, sigmoid(0.6) ≈ 0.646
+
+    float alpha0_0 = 1.0f / (1.0f + expf(-0.5f)); // sigmoid(0.5)
+    float alpha1_0 = 1.0f / (1.0f + expf(-0.8f)); // sigmoid(0.8)
+    float alpha0_1 = 1.0f / (1.0f + expf(-0.3f)); // sigmoid(0.3)
+    float alpha1_1 = 1.0f / (1.0f + expf(-0.6f)); // sigmoid(0.6)
+
+    // Expected output: alpha * activated_input + (1 - alpha) * pre_activation_input
+    // Since input activation is linear, activated_input = pre_activation_input = input
+    // So output = alpha * input + (1 - alpha) * input = input
+    // This should be the same as with linear activations
+    assert(fabs(output_sigmoid(0, 0) - 1.0f) < 1e-6);
+    assert(fabs(output_sigmoid(1, 0) - 3.0f) < 1e-6);
+    assert(fabs(output_sigmoid(0, 1) - 2.0f) < 1e-6);
+    assert(fabs(output_sigmoid(1, 1) - 4.0f) < 1e-6);
+
+    std::cout << "Blending detailed test passed" << std::endl;
+  }
+
+  static void test_input_buffer_usage()
+  {
+    // Test that the input buffer is correctly storing pre-activation values
+    Eigen::MatrixXf input(2, 1);
+    input << 2.0f, 0.5f;
+
+    Eigen::MatrixXf output(1, 1);
+
+    // Test with ReLU activation on input (which will change values < 0 to 0)
+    nam::activations::ActivationReLU relu_act;
+    nam::activations::ActivationIdentity identity_act;
+    nam::gating_activations::BlendingActivation blending_act(&relu_act, &identity_act, 1);
+
+    blending_act.apply(input, output);
+
+    // With input=2.0, ReLU(2.0)=2.0, blend=0.5
+    // output = 0.5 * 2.0 + (1 - 0.5) * 2.0 = 0.5 * 2.0 + 0.5 * 2.0 = 2.0
+    assert(fabs(output(0, 0) - 2.0f) < 1e-6);
+
+    // Test with negative input value
+    Eigen::MatrixXf input2(2, 1);
+    input2 << -1.0f, 0.5f;
+
+    Eigen::MatrixXf output2(1, 1);
+    blending_act.apply(input2, output2);
+
+    // With input=-1.0, ReLU(-1.0)=0.0, blend=0.5
+    // output = 0.5 * 0.0 + (1 - 0.5) * (-1.0) = 0.0 + 0.5 * (-1.0) = -0.5
+    assert(fabs(output2(0, 0) - (-0.5f)) < 1e-6);
+
+    std::cout << "Input buffer usage test passed" << std::endl;
+  }
+};
+
+}; // namespace test_blending_detailed
\ No newline at end of file
diff --git a/tools/test/test_gating_activations.cpp b/tools/test/test_gating_activations.cpp
new file mode 100644
index 0000000..43e414f
--- /dev/null
+++ b/tools/test/test_gating_activations.cpp
@@ -0,0 +1,217 @@
+// Tests for gating activation functions
+
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <vector>
+#include <Eigen/Dense>
+
+#include "NAM/gating_activations.h"
+#include "NAM/activations.h"
+
+namespace test_gating_activations
+{
+
+class TestGatingActivation
+{
+public:
+  static void test_basic_functionality()
+  {
+    // Create test input data (2 rows, 3 columns)
+    Eigen::MatrixXf input(2, 3);
+    input << 1.0f, -1.0f, 0.0f, 0.5f, 0.8f, 1.0f;
+
+    Eigen::MatrixXf output(1, 3);
+
+    // Create gating activation with default activations (1 input channel, 1 gating channel)
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationSigmoid sigmoid_act;
+    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, 1);
+
+    // Apply the activation
+    gating_act.apply(input, output);
+
+    // Basic checks
+    assert(output.rows() == 1);
+    assert(output.cols() == 3);
+
+    // The output should be element-wise multiplication of the two rows
+    // after applying activations
+    std::cout << "GatingActivation basic test passed" << std::endl;
+  }
+
+  static void test_with_custom_activations()
+  {
+    // Create custom activations
+    nam::activations::ActivationLeakyReLU leaky_relu(0.01f);
+    nam::activations::ActivationLeakyReLU leaky_relu2(0.05f);
+
+    // Create test input data
+    Eigen::MatrixXf input(2, 2);
+    input << -1.0f, 1.0f, -2.0f, 0.5f;
+
+    Eigen::MatrixXf output(1, 2);
+
+    // Create gating activation with custom activations
+    nam::gating_activations::GatingActivation gating_act(&leaky_relu, &leaky_relu2, 1);
+
+    // Apply the activation
+    gating_act.apply(input, output);
+
+    // Verify dimensions
+    assert(output.rows() == 1);
+    assert(output.cols() == 2);
+
+    std::cout << "GatingActivation custom activations test passed" << std::endl;
+  }
+
+  static void test_error_handling()
+  {
+    // Test with insufficient rows - should assert
+    // In real-time code, we use asserts instead of exceptions for performance
+    // These tests would normally crash the program due to asserts
+    // In production, these conditions should never occur if the code is used correctly
+  }
+};
+
+class TestBlendingActivation
+{
+public:
+  static void test_basic_functionality()
+  {
+    // Create test input data (2 rows, 3 columns)
+    Eigen::MatrixXf input(2, 3);
+    input << 1.0f, -1.0f, 0.0f, 0.5f, 0.8f, 1.0f;
+
+    Eigen::MatrixXf output(1, 3);
+
+    // Create blending activation (1 input channel)
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationIdentity identity_blend_act;
+    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+
+    // Apply the activation
+    blending_act.apply(input, output);
+
+    // Basic checks
+    assert(output.rows() == 1);
+    assert(output.cols() == 3);
+
+    std::cout << "BlendingActivation basic test passed" << std::endl;
+  }
+
+  static void test_blending_behavior()
+  {
+    // Test blending with different activation functions
+    // Create test input data (2 rows, 2 columns)
+    Eigen::MatrixXf input(2, 2);
+    input << 1.0f, -1.0f, 0.5f, 0.8f;
+
+    Eigen::MatrixXf output(1, 2);
+
+    // Test with default (linear) activations
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationIdentity identity_blend_act;
+    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+    blending_act.apply(input, output);
+
+    // With linear activations, blending should be:
+    // alpha = blend_input (since linear activation does nothing)
+    // output = alpha * input + (1 - alpha) * input = input
+    // So output should equal the first row (input after activation)
+    assert(fabs(output(0, 0) - 1.0f) < 1e-6);
+    assert(fabs(output(0, 1) - (-1.0f)) < 1e-6);
+
+    // Test with sigmoid blending activation
+    nam::activations::Activation* sigmoid_act = nam::activations::Activation::get_activation("Sigmoid");
+    nam::gating_activations::BlendingActivation blending_act2(&identity_act, sigmoid_act, 1);
+    blending_act2.apply(input, output);
+
+    // With sigmoid blending, alpha values should be between 0 and 1
+    // For input 0.5, sigmoid(0.5) ≈ 0.622
+    // For input 0.8, sigmoid(0.8) ≈ 0.690
+    float alpha0 = 1.0f / (1.0f + expf(-0.5f)); // sigmoid(0.5)
+    float alpha1 = 1.0f / (1.0f + expf(-0.8f)); // sigmoid(0.8)
+
+    // Expected output: alpha * activated_input + (1 - alpha) * pre_activation_input
+    // Since input activation is linear, activated_input = pre_activation_input = input
+    // So output = alpha * input + (1 - alpha) * input = input
+    // This is the same as with linear activations
+    assert(fabs(output(0, 0) - 1.0f) < 1e-6);
+    assert(fabs(output(0, 1) - (-1.0f)) < 1e-6);
+
+    std::cout << "BlendingActivation blending behavior test passed" << std::endl;
+  }
+
+  static void test_with_custom_activations()
+  {
+    // Create custom activations
+    nam::activations::ActivationLeakyReLU leaky_relu(0.01f);
+    nam::activations::ActivationLeakyReLU leaky_relu2(0.05f);
+
+    // Create test input data
+    Eigen::MatrixXf input(2, 2);
+    input << -1.0f, 1.0f, -2.0f, 0.5f;
+
+    Eigen::MatrixXf output(1, 2);
+
+    // Create blending activation with custom activations
+    nam::gating_activations::BlendingActivation blending_act(&leaky_relu, &leaky_relu2, 1);
+
+    // Apply the activation
+    blending_act.apply(input, output);
+
+    // Verify dimensions
+    assert(output.rows() == 1);
+    assert(output.cols() == 2);
+
+    std::cout << "BlendingActivation custom activations test passed" << std::endl;
+  }
+
+  static void test_error_handling()
+  {
+    // Test with insufficient rows - should assert
+    Eigen::MatrixXf input(1, 2); // Only 1 row
+    Eigen::MatrixXf output(1, 2);
+
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationIdentity identity_blend_act;
+    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+
+    // This should trigger an assert and terminate the program
+    // We can't easily test asserts in a unit test framework without special handling
+    // For real-time code, we rely on the asserts to catch issues during development
+
+    // Test with invalid number of channels - should throw in the constructor
+    // These tests would normally crash the program due to asserts
+    // In production, these conditions should never occur if the code is used correctly
+  }
+
+  static void test_edge_cases()
+  {
+    // Test with zero input
+    Eigen::MatrixXf input(2, 1);
+    input << 0.0f, 0.0f;
+
+    Eigen::MatrixXf output(1, 1);
+
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationIdentity identity_blend_act;
+    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 1);
+    blending_act.apply(input, output);
+
+    assert(fabs(output(0, 0) - 0.0f) < 1e-6);
+
+    // Test with large values
+    Eigen::MatrixXf input2(2, 1);
+    input2 << 1000.0f, -1000.0f;
+
+    blending_act.apply(input2, output);
+
+    // Should handle large values without issues
+    assert(output.rows() == 1);
+    assert(output.cols() == 1);
+  }
+};
+
+}; // namespace test_gating_activations
diff --git a/tools/test/test_input_buffer_verification.cpp b/tools/test/test_input_buffer_verification.cpp
new file mode 100644
index 0000000..d7d280e
--- /dev/null
+++ b/tools/test/test_input_buffer_verification.cpp
@@ -0,0 +1,90 @@
+// Test to verify that input buffer correctly stores pre-activation values
+
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <vector>
+#include <Eigen/Dense>
+
+#include "NAM/gating_activations.h"
+#include "NAM/activations.h"
+
+namespace test_input_buffer_verification
+{
+
+class TestInputBufferVerification
+{
+public:
+  static void test_buffer_stores_pre_activation_values()
+  {
+    // Create a test case where input activation changes the values
+    Eigen::MatrixXf input(2, 1);
+    input << -2.0f, 0.5f; // Negative input value
+
+    Eigen::MatrixXf output(1, 1);
+
+    // Use ReLU activation which will set negative values to 0
+    nam::activations::ActivationReLU relu_act;
+    nam::activations::ActivationIdentity identity_act;
+    nam::gating_activations::BlendingActivation blending_act(&relu_act, &identity_act, 1);
+
+    // Apply the activation
+    blending_act.apply(input, output);
+
+    std::cout << "Input buffer verification test:" << std::endl;
+    std::cout << "Input: " << input(0, 0) << " (will be modified by ReLU)" << std::endl;
+    std::cout << "Blend value: " << input(1, 0) << std::endl;
+    std::cout << "Output: " << output(0, 0) << std::endl;
+
+    // Expected behavior:
+    // 1. Store pre-activation input in buffer: input_buffer = -2.0f
+    // 2. Apply ReLU to input: activated_input = max(-2.0f, 0) = 0.0f
+    // 3. Apply linear activation to blend: alpha = 0.5f (no change)
+    // 4. Compute output: alpha * activated_input + (1 - alpha) * input_buffer
+    //    = 0.5f * 0.0f + 0.5f * (-2.0f) = -1.0f
+
+    float expected = 0.5f * 0.0f + 0.5f * (-2.0f); // = -1.0f
+    assert(fabs(output(0, 0) - expected) < 1e-6);
+
+    std::cout << "Expected: " << expected << std::endl;
+    std::cout << "Input buffer verification test passed!" << std::endl;
+  }
+
+  static void test_buffer_with_different_activations()
+  {
+    // Test with LeakyReLU which modifies negative values differently
+    Eigen::MatrixXf input(2, 1);
+    input << -1.0f, 0.8f;
+
+    Eigen::MatrixXf output(1, 1);
+
+    // Use LeakyReLU with slope 0.1
+    nam::activations::ActivationLeakyReLU leaky_relu(0.1f);
+    nam::activations::ActivationIdentity identity_act;
+    nam::gating_activations::BlendingActivation blending_act(&leaky_relu, &identity_act, 1);
+
+    blending_act.apply(input, output);
+
+    std::cout << "LeakyReLU buffer test:" << std::endl;
+    std::cout << "Input: " << input(0, 0) << std::endl;
+    std::cout << "Blend value: " << input(1, 0) << std::endl;
+    std::cout << "Output: " << output(0, 0) << std::endl;
+
+    // Expected behavior:
+    // 1. Store pre-activation input in buffer: input_buffer = -1.0f
+    // 2. Apply LeakyReLU: activated_input = (-1.0f > 0) ? -1.0f : 0.1f * -1.0f = -0.1f
+    // 3. Apply linear activation to blend: alpha = 0.8f
+    // 4. Compute output: alpha * activated_input + (1 - alpha) * input_buffer
+    //    = 0.8f * (-0.1f) + 0.2f * (-1.0f) = -0.08f - 0.2f = -0.28f
+
+    float activated_input = (-1.0f > 0) ? -1.0f : 0.1f * -1.0f; // = -0.1f
+    float expected = 0.8f * activated_input + 0.2f * (-1.0f); // = -0.28f
+
+    assert(fabs(output(0, 0) - expected) < 1e-6);
+
+    std::cout << "Expected: " << expected << std::endl;
+    std::cout << "LeakyReLU buffer test passed!" << std::endl;
+  }
+};
+
+}; // namespace test_input_buffer_verification
\ No newline at end of file
diff --git a/tools/test/test_wavenet_gating_compatibility.cpp b/tools/test/test_wavenet_gating_compatibility.cpp
new file mode 100644
index 0000000..44af69b
--- /dev/null
+++ b/tools/test/test_wavenet_gating_compatibility.cpp
@@ -0,0 +1,189 @@
+// Test to verify that our gating implementation matches the wavenet behavior
+
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <vector>
+#include <Eigen/Dense>
+
+#include "NAM/gating_activations.h"
+#include "NAM/activations.h"
+
+namespace test_wavenet_gating_compatibility
+{
+
+class TestWavenetGatingCompatibility
+{
+public:
+  static void test_wavenet_style_gating()
+  {
+    // Simulate wavenet scenario: 2 channels (input + gating), multiple samples
+    const int channels = 2;
+    const int num_samples = 3;
+
+    // Create input matrix similar to wavenet's _z matrix
+    // First 'channels' rows are input, next 'channels' rows are gating
+    Eigen::MatrixXf input(2 * channels, num_samples);
+    input << 1.0f, -0.5f, 0.2f, // Input channel 1
+      0.3f, 0.1f, -0.4f, // Input channel 2
+      0.8f, 0.6f, 0.9f, // Gating channel 1
+      0.4f, 0.2f, 0.7f; // Gating channel 2
+
+    Eigen::MatrixXf output(channels, num_samples);
+
+    // Create gating activation that matches wavenet behavior
+    // Wavenet uses: input activation (default/linear) and sigmoid for gating
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationSigmoid sigmoid_act;
+    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+
+    // Apply the activation
+    gating_act.apply(input, output);
+
+    // Verify dimensions
+    assert(output.rows() == channels);
+    assert(output.cols() == num_samples);
+
+    // Verify that the output is the element-wise product of input and gating channels
+    // after applying activations
+    for (int c = 0; c < channels; c++)
+    {
+      for (int s = 0; s < num_samples; s++)
+      {
+        // Input channel value (no activation applied - linear)
+        float input_val = input(c, s);
+
+        // Gating channel value (sigmoid activation applied)
+        float gating_val = input(channels + c, s);
+        float sigmoid_gating = 1.0f / (1.0f + expf(-gating_val));
+
+        // Expected output
+        float expected = input_val * sigmoid_gating;
+
+        // Check if they match
+        if (fabs(output(c, s) - expected) > 1e-6)
+        {
+          std::cerr << "Mismatch at channel " << c << ", sample " << s << ": expected " << expected << ", got "
+                    << output(c, s) << std::endl;
+          assert(false);
+        }
+      }
+    }
+
+    std::cout << "Wavenet gating compatibility test passed" << std::endl;
+  }
+
+  static void test_column_by_column_processing()
+  {
+    // Test that our implementation processes column-by-column like wavenet
+    const int channels = 1;
+    const int num_samples = 4;
+
+    Eigen::MatrixXf input(2, num_samples);
+    input << 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f;
+
+    Eigen::MatrixXf output(channels, num_samples);
+
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationSigmoid sigmoid_act;
+    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+    gating_act.apply(input, output);
+
+    // Verify each column was processed independently
+    for (int s = 0; s < num_samples; s++)
+    {
+      float input_val = input(0, s);
+      float gating_val = input(1, s);
+      float sigmoid_gating = 1.0f / (1.0f + expf(-gating_val));
+      float expected = input_val * sigmoid_gating;
+
+      assert(fabs(output(0, s) - expected) < 1e-6);
+    }
+
+    std::cout << "Column-by-column processing test passed" << std::endl;
+  }
+
+  static void test_memory_contiguity()
+  {
+    // Test that our implementation handles memory contiguity correctly
+    // This is important for column-major matrices
+    const int channels = 3;
+    const int num_samples = 2;
+
+    Eigen::MatrixXf input(2 * channels, num_samples);
+    // Fill with some values
+    for (int i = 0; i < 2 * channels; i++)
+    {
+      for (int j = 0; j < num_samples; j++)
+      {
+        input(i, j) = static_cast<float>(i * num_samples + j + 1);
+      }
+    }
+
+    Eigen::MatrixXf output(channels, num_samples);
+
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationSigmoid sigmoid_act;
+    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+
+    // This should not crash or produce incorrect results due to memory contiguity issues
+    gating_act.apply(input, output);
+
+    // Verify the results are correct
+    for (int c = 0; c < channels; c++)
+    {
+      for (int s = 0; s < num_samples; s++)
+      {
+        float input_val = input(c, s);
+        float gating_val = input(channels + c, s);
+        float sigmoid_gating = 1.0f / (1.0f + expf(-gating_val));
+        float expected = input_val * sigmoid_gating;
+
+        assert(fabs(output(c, s) - expected) < 1e-6);
+      }
+    }
+
+    std::cout << "Memory contiguity test passed" << std::endl;
+  }
+
+  static void test_multiple_channels()
+  {
+    // Test with multiple equal input and gating channels (wavenet style)
+    const int channels = 2;
+    const int num_samples = 2;
+
+    Eigen::MatrixXf input(2 * channels, num_samples);
+    input << 1.0f, 2.0f, // Input channels
+      3.0f, 4.0f,
+      5.0f, 6.0f, // Gating channels
+      7.0f, 8.0f;
+
+    Eigen::MatrixXf output(channels, num_samples);
+
+    nam::activations::ActivationIdentity identity_act;
+    nam::activations::ActivationSigmoid sigmoid_act;
+    nam::gating_activations::GatingActivation gating_act(&identity_act, &sigmoid_act, channels);
+    gating_act.apply(input, output);
+
+    // Verify dimensions
+    assert(output.rows() == channels);
+    assert(output.cols() == num_samples);
+
+    // Verify that each input channel is multiplied by corresponding gating channel
+    for (int c = 0; c < channels; c++)
+    {
+      for (int s = 0; s < num_samples; s++)
+      {
+        float input_val = input(c, s);
+        float gating_val = input(channels + c, s);
+        float sigmoid_gating = 1.0f / (1.0f + expf(-gating_val));
+        float expected = input_val * sigmoid_gating;
+
+        assert(fabs(output(c, s) - expected) < 1e-6);
+      }
+    }
+
+    std::cout << "Multiple channels test passed" << std::endl;
+  }
+};
+
+}; // namespace test_wavenet_gating_compatibility
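Usage sketch (not part of the diff above): a minimal, hypothetical driver showing how the new GatingActivation and BlendingActivation classes could be called from user code. The class names, constructors, and apply() signatures come from NAM/gating_activations.h as added in this change; the main() function, matrix sizes, and sample values are illustrative assumptions only.

#include <iostream>
#include <Eigen/Dense>
#include "NAM/activations.h"
#include "NAM/gating_activations.h"

int main()
{
  // One input channel plus one gate/blend channel, four samples (one column per sample).
  Eigen::MatrixXf in(2, 4);
  in << 0.5f, -0.3f, 1.0f, 0.0f, // input channel
    0.2f, 0.8f, -0.5f, 1.5f; // gate / blend channel
  Eigen::MatrixXf out(1, 4);

  // WaveNet-style gate: identity on the input path, sigmoid on the gate, multiplied element-wise.
  nam::activations::ActivationIdentity identity;
  nam::activations::ActivationSigmoid sigmoid;
  nam::gating_activations::GatingActivation gate(&identity, &sigmoid, 1);
  gate.apply(in, out);
  std::cout << "gated:" << std::endl << out << std::endl;

  // Blend: out = alpha * tanh(x) + (1 - alpha) * x, where alpha = sigmoid(blend channel).
  nam::activations::ActivationTanh tanh_act;
  nam::gating_activations::BlendingActivation blend(&tanh_act, &sigmoid, 1);
  blend.apply(in, out);
  std::cout << "blended:" << std::endl << out << std::endl;
  return 0;
}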