Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions NAM/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ class Activation
static std::unordered_map<std::string, Activation*> _activations;
};

// identity function activation
class ActivationIdentity : public nam::activations::Activation
{
public:
ActivationIdentity() = default;
~ActivationIdentity() = default;
// Inherit the default apply methods which do nothing
};

class ActivationTanh : public Activation
{
public:
Expand Down
170 changes: 170 additions & 0 deletions NAM/gating_activations.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#pragma once

#include <string>
#include <cmath> // expf
#include <unordered_map>
#include <Eigen/Dense>
#include <functional>
#include <stdexcept>
#include "activations.h"

namespace nam
{
namespace gating_activations
{

// Default linear activation (identity function)
// NOTE(review): this appears to duplicate nam::activations::ActivationIdentity
// from activations.h — consider consolidating to a single identity class.
class IdentityActivation : public nam::activations::Activation
{
public:
  IdentityActivation() = default;
  ~IdentityActivation() = default;
  // Inherit the default apply methods which do nothing (linear/identity)
};

class GatingActivation
{
public:
/**
* Constructor for GatingActivation
* @param input_act Activation function for input channels
* @param gating_act Activation function for gating channels
* @param input_channels Number of input channels (default: 1)
* @param gating_channels Number of gating channels (default: 1)
*/
GatingActivation(activations::Activation* input_act, activations::Activation* gating_act, int input_channels = 1)
: input_activation(input_act)
, gating_activation(gating_act)
, num_channels(input_channels)
{
assert(num_channels > 0);
}

~GatingActivation() = default;

/**
* Apply gating activation to input matrix
* @param input Input matrix with shape (input_channels + gating_channels) x num_samples
* @param output Output matrix with shape input_channels x num_samples
*/
void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
{
// Validate input dimensions (assert for real-time performance)
const int total_channels = 2 * num_channels;
assert(input.rows() == total_channels);
assert(output.rows() == num_channels);
assert(output.cols() == input.cols());

// Process column-by-column to ensure memory contiguity (important for column-major matrices)
const int num_samples = input.cols();
for (int i = 0; i < num_samples; i++)
{
// Apply activation to input channels
Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
input_activation->apply(input_block);

// Apply activation to gating channels
Eigen::MatrixXf gating_block = input.block(num_channels, i, num_channels, 1);
gating_activation->apply(gating_block);

// Element-wise multiplication and store result
// For wavenet compatibility, we assume one-to-one mapping
output.block(0, i, num_channels, 1) = input_block.array() * gating_block.array();
}
}

/**
* Get the total number of input channels required
*/
int get_input_channels() const { return 2 * num_channels; }

/**
* Get the number of output channels
*/
int get_output_channels() const { return num_channels; }

private:
activations::Activation* input_activation;
activations::Activation* gating_activation;
int num_channels;
};

class BlendingActivation
{
public:
/**
* Constructor for BlendingActivation
* @param input_act Activation function for input channels
* @param blend_act Activation function for blending channels
* @param input_channels Number of input channels
*/
BlendingActivation(activations::Activation* input_act, activations::Activation* blend_act, int input_channels = 1)
: input_activation(input_act)
, blending_activation(blend_act)
, num_channels(input_channels)
{
if (num_channels <= 0)
{
throw std::invalid_argument("BlendingActivation: number of input channels must be positive");
}
// Initialize input buffer with correct size
// Note: current code copies column-by-column so we only need (num_channels, 1)
input_buffer.resize(num_channels, 1);
}

~BlendingActivation() = default;

/**
* Apply blending activation to input matrix
* @param input Input matrix with shape (input_channels + blend_channels) x num_samples
* @param output Output matrix with shape input_channels x num_samples
*/
void apply(Eigen::MatrixXf& input, Eigen::MatrixXf& output)
{
// Validate input dimensions (assert for real-time performance)
const int total_channels = num_channels * 2; // 2*channels in, channels out
assert(input.rows() == total_channels);
assert(output.rows() == num_channels);
assert(output.cols() == input.cols());

// Process column-by-column to ensure memory contiguity
const int num_samples = input.cols();
for (int i = 0; i < num_samples; i++)
{
// Store pre-activation input values in buffer
input_buffer = input.block(0, i, num_channels, 1);

// Apply activation to input channels
Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
input_activation->apply(input_block);

// Apply activation to blend channels to compute alpha
Eigen::MatrixXf blend_block = input.block(num_channels, i, num_channels, 1);
blending_activation->apply(blend_block);

// Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
output.block(0, i, num_channels, 1) =
blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array();
}
}

/**
* Get the total number of input channels required
*/
int get_input_channels() const { return 2 * num_channels; }

/**
* Get the number of output channels
*/
int get_output_channels() const { return num_channels; }

private:
activations::Activation* input_activation;
activations::Activation* blending_activation;
int num_channels;
Eigen::MatrixXf input_buffer;
};


}; // namespace gating_activations
}; // namespace nam
31 changes: 30 additions & 1 deletion tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
#include "test/test_get_dsp.cpp"
#include "test/test_wavenet.cpp"
#include "test/test_fast_lut.cpp"
#include "test/test_gating_activations.cpp"
#include "test/test_wavenet_gating_compatibility.cpp"
#include "test/test_blending_detailed.cpp"
#include "test/test_input_buffer_verification.cpp"

int main()
{
Expand All @@ -24,7 +28,7 @@ int main()
test_activations::TestPReLU::test_core_function();
test_activations::TestPReLU::test_per_channel_behavior();
// This is enforced by an assert so it doesn't need to be tested
//test_activations::TestPReLU::test_wrong_number_of_channels();
// test_activations::TestPReLU::test_wrong_number_of_channels();

test_dsp::test_construct();
test_dsp::test_get_input_level();
Expand All @@ -44,6 +48,31 @@ int main()
test_lut::TestFastLUT::test_sigmoid();
test_lut::TestFastLUT::test_tanh();

// Gating activations tests
test_gating_activations::TestGatingActivation::test_basic_functionality();
test_gating_activations::TestGatingActivation::test_with_custom_activations();
// test_gating_activations::TestGatingActivation::test_error_handling();

// Wavenet gating compatibility tests
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_wavenet_style_gating();
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_column_by_column_processing();
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_memory_contiguity();
test_wavenet_gating_compatibility::TestWavenetGatingCompatibility::test_multiple_channels();

test_gating_activations::TestBlendingActivation::test_basic_functionality();
test_gating_activations::TestBlendingActivation::test_blending_behavior();
test_gating_activations::TestBlendingActivation::test_with_custom_activations();
// test_gating_activations::TestBlendingActivation::test_error_handling();
test_gating_activations::TestBlendingActivation::test_edge_cases();

// Detailed blending tests
test_blending_detailed::TestBlendingDetailed::test_blending_with_different_activations();
test_blending_detailed::TestBlendingDetailed::test_input_buffer_usage();

// Input buffer verification tests
test_input_buffer_verification::TestInputBufferVerification::test_buffer_stores_pre_activation_values();
test_input_buffer_verification::TestInputBufferVerification::test_buffer_with_different_activations();

std::cout << "Success!" << std::endl;
return 0;
}
115 changes: 115 additions & 0 deletions tools/test/test_blending_detailed.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Detailed test for BlendingActivation behavior

#include <cassert>
#include <string>
#include <vector>
#include <cmath>
#include <iostream>

#include "NAM/gating_activations.h"
#include "NAM/activations.h"

namespace test_blending_detailed
{

// Detailed tests for BlendingActivation: verifies the blend formula
// alpha * f(x) + (1 - alpha) * x against hand-computed expectations.
class TestBlendingDetailed
{
public:
  // Checks blending with identity and sigmoid blend activations on a
  // 2-channel, 2-sample input.
  static void test_blending_with_different_activations()
  {
    // Test case: 2 input channels, so we need 4 total input channels (2*channels in)
    Eigen::MatrixXf input(4, 2); // 4 rows (2 input + 2 blending), 2 samples
    input << 1.0f, 2.0f, // Input channel 1
      3.0f, 4.0f, // Input channel 2
      0.5f, 0.8f, // Blending channel 1
      0.3f, 0.6f; // Blending channel 2

    Eigen::MatrixXf output(2, 2); // 2 output channels, 2 samples

    // Test with default (linear) activations
    nam::activations::ActivationIdentity identity_act;
    nam::activations::ActivationIdentity identity_blend_act;
    nam::gating_activations::BlendingActivation blending_act(&identity_act, &identity_blend_act, 2);
    blending_act.apply(input, output);

    std::cout << "Blending with linear activations:" << std::endl;
    std::cout << "Input:" << std::endl << input << std::endl;
    std::cout << "Output:" << std::endl << output << std::endl;

    // With linear activations alpha = blend input, and since the input
    // activation is also linear:
    // output = alpha * input + (1 - alpha) * input = input
    assert(std::fabs(output(0, 0) - 1.0f) < 1e-6);
    assert(std::fabs(output(1, 0) - 3.0f) < 1e-6);
    assert(std::fabs(output(0, 1) - 2.0f) < 1e-6);
    assert(std::fabs(output(1, 1) - 4.0f) < 1e-6);

    // Test with sigmoid blending activation
    nam::activations::Activation* sigmoid_act = nam::activations::Activation::get_activation("Sigmoid");
    nam::gating_activations::BlendingActivation blending_act_sigmoid(&identity_act, sigmoid_act, 2);

    Eigen::MatrixXf output_sigmoid(2, 2);
    blending_act_sigmoid.apply(input, output_sigmoid);

    std::cout << "Blending with sigmoid blending activation:" << std::endl;
    std::cout << "Output:" << std::endl << output_sigmoid << std::endl;

    // With sigmoid blending, alpha values lie strictly between 0 and 1:
    // sigmoid(0.5) ~ 0.622, sigmoid(0.8) ~ 0.690,
    // sigmoid(0.3) ~ 0.574, sigmoid(0.6) ~ 0.646
    const float alpha0_0 = 1.0f / (1.0f + std::exp(-0.5f)); // sigmoid(0.5)
    const float alpha1_0 = 1.0f / (1.0f + std::exp(-0.8f)); // sigmoid(0.8)
    const float alpha0_1 = 1.0f / (1.0f + std::exp(-0.3f)); // sigmoid(0.3)
    const float alpha1_1 = 1.0f / (1.0f + std::exp(-0.6f)); // sigmoid(0.6)

    // Expected output: alpha * activated_input + (1 - alpha) * pre_activation_input.
    // The input activation is linear, so the blend collapses to the input for
    // any alpha; compute the expectation explicitly anyway so the alpha values
    // actually participate in the check (previously they were unused locals).
    assert(std::fabs(output_sigmoid(0, 0) - (alpha0_0 * 1.0f + (1.0f - alpha0_0) * 1.0f)) < 1e-6);
    assert(std::fabs(output_sigmoid(1, 0) - (alpha1_0 * 3.0f + (1.0f - alpha1_0) * 3.0f)) < 1e-6);
    assert(std::fabs(output_sigmoid(0, 1) - (alpha0_1 * 2.0f + (1.0f - alpha0_1) * 2.0f)) < 1e-6);
    assert(std::fabs(output_sigmoid(1, 1) - (alpha1_1 * 4.0f + (1.0f - alpha1_1) * 4.0f)) < 1e-6);

    std::cout << "Blending detailed test passed" << std::endl;
  }

  // Checks that blending uses the PRE-activation input for the (1 - alpha)
  // term, which only matters when the input activation changes the value
  // (here: ReLU on a negative input).
  static void test_input_buffer_usage()
  {
    Eigen::MatrixXf input(2, 1);
    input << 2.0f, 0.5f;

    Eigen::MatrixXf output(1, 1);

    // ReLU on the input path (clamps negative values to 0), identity on blend
    nam::activations::ActivationReLU relu_act;
    nam::activations::ActivationIdentity identity_act;
    nam::gating_activations::BlendingActivation blending_act(&relu_act, &identity_act, 1);

    blending_act.apply(input, output);

    // With input=2.0, ReLU(2.0)=2.0, blend=0.5
    // output = 0.5 * 2.0 + (1 - 0.5) * 2.0 = 2.0
    assert(std::fabs(output(0, 0) - 2.0f) < 1e-6);

    // Test with negative input value
    Eigen::MatrixXf input2(2, 1);
    input2 << -1.0f, 0.5f;

    Eigen::MatrixXf output2(1, 1);
    blending_act.apply(input2, output2);

    // With input=-1.0, ReLU(-1.0)=0.0, blend=0.5
    // output = 0.5 * 0.0 + (1 - 0.5) * (-1.0) = -0.5
    assert(std::fabs(output2(0, 0) - (-0.5f)) < 1e-6);

    std::cout << "Input buffer usage test passed" << std::endl;
  }
};

}; // namespace test_blending_detailed
Loading