diff --git a/.gitignore b/.gitignore index 259148f..8604b38 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ *.exe *.out *.app + +.vscode/ diff --git a/NAM/activations.h b/NAM/activations.h index 4429964..2813c35 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include // expf #include diff --git a/NAM/conv1d.cpp b/NAM/conv1d.cpp new file mode 100644 index 0000000..6495506 --- /dev/null +++ b/NAM/conv1d.cpp @@ -0,0 +1,132 @@ +#include "conv1d.h" + +namespace nam +{ +// Conv1D ===================================================================== + +void Conv1D::set_weights_(std::vector::iterator& weights) +{ + if (this->_weight.size() > 0) + { + const long out_channels = this->_weight[0].rows(); + const long in_channels = this->_weight[0].cols(); + // Crazy ordering because that's how it gets flattened. + for (auto i = 0; i < out_channels; i++) + for (auto j = 0; j < in_channels; j++) + for (size_t k = 0; k < this->_weight.size(); k++) + this->_weight[k](i, j) = *(weights++); + } + for (long i = 0; i < this->_bias.size(); i++) + this->_bias(i) = *(weights++); +} + +void Conv1D::set_size_(const int in_channels, const int out_channels, const int kernel_size, const bool do_bias, + const int _dilation) +{ + this->_weight.resize(kernel_size); + for (size_t i = 0; i < this->_weight.size(); i++) + this->_weight[i].resize(out_channels, + in_channels); // y = Ax, input array (C,L) + if (do_bias) + this->_bias.resize(out_channels); + else + this->_bias.resize(0); + this->_dilation = _dilation; +} + +void Conv1D::set_size_and_weights_(const int in_channels, const int out_channels, const int kernel_size, + const int _dilation, const bool do_bias, std::vector::iterator& weights) +{ + this->set_size_(in_channels, out_channels, kernel_size, do_bias, _dilation); + this->set_weights_(weights); +} + +void Conv1D::SetMaxBufferSize(const int maxBufferSize) +{ + _max_buffer_size = maxBufferSize; + + // Calculate receptive field (maximum lookback needed) + const long kernel_size = get_kernel_size(); + const long dilation = get_dilation(); + const long receptive_field = kernel_size > 0 ? 
(kernel_size - 1) * dilation : 0; + + const long in_channels = get_in_channels(); + + // Initialize input ring buffer + // Set max lookback before Reset so that Reset() can use it to calculate storage size + // Reset() will calculate storage size as: 2 * max_lookback + max_buffer_size + _input_buffer.SetMaxLookback(receptive_field); + _input_buffer.Reset(in_channels, maxBufferSize); + + // Pre-allocate output matrix + const long out_channels = get_out_channels(); + _output.resize(out_channels, maxBufferSize); + _output.setZero(); +} + + +void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) +{ + // Write input to ring buffer + _input_buffer.Write(input, num_frames); + + // Zero output before processing + _output.leftCols(num_frames).setZero(); + + // Process from ring buffer with dilation lookback + // After Write(), data is at positions [_write_pos, _write_pos+num_frames-1] + // For kernel tap k with offset, we need to read from _write_pos + offset + // The offset is negative (looking back), so _write_pos + offset reads from earlier positions + // The original process_() reads: input.middleCols(i_start + offset, ncols) + // where i_start is the current position and offset is negative for lookback + for (size_t k = 0; k < this->_weight.size(); k++) + { + const long offset = this->_dilation * (k + 1 - (long)this->_weight.size()); + // Offset is negative (looking back) + // Read from position: _write_pos + offset + // Since offset is negative, we compute lookback = -offset to read from _write_pos - lookback + const long lookback = -offset; + + // Read num_frames starting from write_pos + offset (which is write_pos - lookback) + auto input_block = _input_buffer.Read(num_frames, lookback); + + // Perform convolution: output += weight[k] * input_block + _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block; + } + + // Add bias if present + if (this->_bias.size() > 0) + { + _output.leftCols(num_frames).colwise() += this->_bias; + } + + // Advance ring buffer write pointer after processing + _input_buffer.Advance(num_frames); +} + +void Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols, + const long j_start) const +{ + // This is the clever part ;) + for (size_t k = 0; k < this->_weight.size(); k++) + { + const long offset = this->_dilation * (k + 1 - this->_weight.size()); + if (k == 0) + output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols); + else + output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols); + } + if (this->_bias.size() > 0) + { + output.middleCols(j_start, ncols).colwise() += this->_bias; + } +} + +long Conv1D::get_num_weights() const +{ + long num_weights = this->_bias.size(); + for (size_t i = 0; i < this->_weight.size(); i++) + num_weights += this->_weight[i].size(); + return num_weights; +} +} // namespace nam diff --git a/NAM/conv1d.h b/NAM/conv1d.h new file mode 100644 index 0000000..44dec89 --- /dev/null +++ b/NAM/conv1d.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include "ring_buffer.h" + +namespace nam +{ +class Conv1D +{ +public: + Conv1D() { this->_dilation = 1; }; + Conv1D(const int in_channels, const int out_channels, const int kernel_size, const int bias, const int dilation) + { + set_size_(in_channels, out_channels, kernel_size, bias, dilation); + }; + void set_weights_(std::vector::iterator& weights); + void set_size_(const int in_channels, const int 
out_channels, const int kernel_size, const bool do_bias, + const int _dilation); + void set_size_and_weights_(const int in_channels, const int out_channels, const int kernel_size, const int _dilation, + const bool do_bias, std::vector::iterator& weights); + // Reset the ring buffer and pre-allocate output buffer + // :param sampleRate: Unused, for interface consistency + // :param maxBufferSize: Maximum buffer size for output buffer and to size ring buffer + void SetMaxBufferSize(const int maxBufferSize); + // Get the entire internal output buffer. This is intended for internal wiring + // between layers; callers should treat the buffer as pre-allocated storage + // and only consider the first `num_frames` columns valid for a given + // processing call. Slice with .leftCols(num_frames) as needed. + Eigen::MatrixXf& GetOutput() { return _output; } + const Eigen::MatrixXf& GetOutput() const { return _output; } + // Process input and write to internal output buffer + // :param input: Input matrix (channels x num_frames) + // :param num_frames: Number of frames to process + void Process(const Eigen::MatrixXf& input, const int num_frames); + // Process from input to output (legacy method, kept for compatibility) + // Rightmost indices of input go from i_start for ncols, + // Indices on output for from j_start (to j_start + ncols - i_start) + void process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols, + const long j_start) const; + long get_in_channels() const { return this->_weight.size() > 0 ? this->_weight[0].cols() : 0; }; + long get_kernel_size() const { return this->_weight.size(); }; + long get_num_weights() const; + long get_out_channels() const { return this->_weight.size() > 0 ? this->_weight[0].rows() : 0; }; + int get_dilation() const { return this->_dilation; }; + bool has_bias() const { return this->_bias.size() > 0; }; + +protected: + // conv[kernel](cout, cin) + std::vector _weight; + Eigen::VectorXf _bias; + int _dilation; + +private: + RingBuffer _input_buffer; // Ring buffer for input (channels x buffer_size) + Eigen::MatrixXf _output; // Pre-allocated output buffer (out_channels x maxBufferSize) + int _max_buffer_size = 0; // Stored maxBufferSize +}; +} // namespace nam diff --git a/NAM/convnet.cpp b/NAM/convnet.cpp index 535fac6..3999f61 100644 --- a/NAM/convnet.cpp +++ b/NAM/convnet.cpp @@ -59,15 +59,69 @@ void nam::convnet::ConvNetBlock::set_weights_(const int in_channels, const int o this->activation = activations::Activation::get_activation(activation); } +void nam::convnet::ConvNetBlock::SetMaxBufferSize(const int maxBufferSize) +{ + this->conv.SetMaxBufferSize(maxBufferSize); + const long out_channels = get_out_channels(); + this->_output.resize(out_channels, maxBufferSize); + this->_output.setZero(); +} + +void nam::convnet::ConvNetBlock::Process(const Eigen::MatrixXf& input, const int num_frames) +{ + // Process input with Conv1D + this->conv.Process(input, num_frames); + + // Get output from Conv1D (this is a block reference to conv's _output buffer) + auto conv_output_block = this->conv.GetOutput().leftCols(num_frames); + + // Copy conv output to our own output buffer + this->_output.leftCols(num_frames) = conv_output_block; + + // Apply batchnorm if needed + if (this->_batchnorm) + { + // Batchnorm expects indices, so we use 0 to num_frames for our output matrix + this->batchnorm.process_(this->_output, 0, num_frames); + } + + // Apply activation + this->activation->apply(this->_output.leftCols(num_frames)); +} + 
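For orientation, here is a minimal sketch (not part of the patch) of how the new buffer-owning ConvNetBlock API is intended to be driven; the helper name run_two_blocks, the buffer sizes, and the frame counts are illustrative assumptions, and both blocks are assumed to have had set_weights_() called already.

#include <Eigen/Dense>
#include "convnet.h"

static void run_two_blocks(nam::convnet::ConvNetBlock& block0, nam::convnet::ConvNetBlock& block1)
{
  const int maxBufferSize = 512;
  block0.SetMaxBufferSize(maxBufferSize); // pre-allocates the Conv1D ring buffer and _output
  block1.SetMaxBufferSize(maxBufferSize);

  const int num_frames = 64;
  Eigen::MatrixXf input = Eigen::MatrixXf::Zero(1, num_frames); // 1 input channel

  block0.Process(input, num_frames); // conv -> (batchnorm) -> activation into block0's _output
  // Copy the valid columns into a MatrixXf before handing them to the next block,
  // as ConvNet::process does, so a Block expression never crosses the API boundary.
  Eigen::MatrixXf x1 = block0.GetOutput(num_frames);
  block1.Process(x1, num_frames);
  Eigen::MatrixXf y = block1.GetOutput(num_frames); // only these columns are valid
  (void)y;
}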
+Eigen::Block nam::convnet::ConvNetBlock::GetOutput(const int num_frames) +{ + return this->_output.block(0, 0, this->_output.rows(), num_frames); +} + void nam::convnet::ConvNetBlock::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, - const long i_end) const + const long i_end) { const long ncols = i_end - i_start; - this->conv.process_(input, output, i_start, ncols, i_start); + // Extract input slice and process with Conv1D + Eigen::MatrixXf input_slice = input.middleCols(i_start, ncols); + this->conv.Process(input_slice, (int)ncols); + + // Get output from Conv1D (this is a block reference to _output buffer) + auto conv_output_block = this->conv.GetOutput().leftCols((int)ncols); + + // For batchnorm, we need a matrix reference (not a block) + // Create a temporary matrix from the block, process it, then copy back + Eigen::MatrixXf temp_output = conv_output_block; + + // Apply batchnorm if needed if (this->_batchnorm) - this->batchnorm.process_(output, i_start, i_end); + { + // Batchnorm expects indices, so we use 0 to ncols for our temp matrix + this->batchnorm.process_(temp_output, 0, ncols); + } - this->activation->apply(output.middleCols(i_start, ncols)); + // Apply activation + this->activation->apply(temp_output); + + // Copy to Conv1D's output buffer and to output matrix + conv_output_block = temp_output; + output.middleCols(i_start, ncols) = temp_output; } long nam::convnet::ConvNetBlock::get_out_channels() const @@ -102,9 +156,10 @@ nam::convnet::ConvNet::ConvNet(const int channels, const std::vector& dilat std::vector::iterator it = weights.begin(); for (size_t i = 0; i < dilations.size(); i++) this->_blocks[i].set_weights_(i == 0 ? 1 : channels, channels, dilations[i], batchnorm, activation, it); - this->_block_vals.resize(this->_blocks.size() + 1); - for (auto& matrix : this->_block_vals) - matrix.setZero(); + // Only need _block_vals for the head (one entry) + // Conv1D layers manage their own buffers now + this->_block_vals.resize(1); + this->_block_vals[0].setZero(); std::fill(this->_input_buffer.begin(), this->_input_buffer.end(), 0.0f); this->_head = _Head(channels, it); if (it != weights.end()) @@ -123,14 +178,49 @@ void nam::convnet::ConvNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const // Main computation! 
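Each block above owns a Conv1D with its own ring buffer, so it may help to see the sizing arithmetic from Conv1D::SetMaxBufferSize and RingBuffer::Reset worked through once; the kernel size, dilation, and buffer size below are assumed values, not taken from this diff.

// Lookback per layer is (kernel_size - 1) * dilation; ring storage is 2 * lookback + maxBufferSize.
constexpr long kernel_size = 3;
constexpr long dilation = 8;
constexpr long maxBufferSize = 512;
constexpr long lookback = (kernel_size - 1) * dilation; // 16 frames of history
constexpr long storage = 2 * lookback + maxBufferSize;  // 544 columns of storage
static_assert(lookback == 16 && storage == 544, "sizing example");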
const long i_start = this->_input_buffer_offset; const long i_end = i_start + num_frames; - // TODO one unnecessary copy :/ #speed - for (auto i = i_start; i < i_end; i++) - this->_block_vals[0](0, i) = this->_input_buffer[i]; + + // Convert input buffer to matrix for first layer + Eigen::MatrixXf input_matrix(1, num_frames); + for (int i = 0; i < num_frames; i++) + input_matrix(0, i) = this->_input_buffer[i_start + i]; + + // Process through ConvNetBlock layers + // Each block now uses Conv1D's internal buffers via Process() and GetOutput() for (size_t i = 0; i < this->_blocks.size(); i++) - this->_blocks[i].process_(this->_block_vals[i], this->_block_vals[i + 1], i_start, i_end); - // TODO clean up this allocation - this->_head.process_(this->_block_vals[this->_blocks.size()], this->_head_output, i_start, i_end); - // Copy to required output array (TODO tighten this up) + { + // Get input for this block + Eigen::MatrixXf block_input; + if (i == 0) + { + // First block uses the input matrix + block_input = input_matrix; + } + else + { + // Subsequent blocks use output from previous block + auto prev_output = this->_blocks[i - 1].GetOutput(num_frames); + block_input = prev_output; // Copy to matrix + } + + // Process block (handles Conv1D, batchnorm, and activation internally) + this->_blocks[i].Process(block_input, num_frames); + } + + // Process head with output from last Conv1D + // Head still needs the old interface, so we need to provide it via a matrix + // We still need _block_vals[0] for the head interface + if (this->_block_vals[0].rows() != this->_blocks.back().get_out_channels() + || this->_block_vals[0].cols() != (long)this->_input_buffer.size()) + { + this->_block_vals[0].resize(this->_blocks.back().get_out_channels(), this->_input_buffer.size()); + } + // Copy last block output to _block_vals for head + auto last_output = this->_blocks.back().GetOutput(num_frames); + this->_block_vals[0].middleCols(i_start, num_frames) = last_output; + + this->_head.process_(this->_block_vals[0], this->_head_output, i_start, i_end); + + // Copy to required output array for (int s = 0; s < num_frames; s++) output[s] = this->_head_output(s); @@ -144,45 +234,41 @@ void nam::convnet::ConvNet::_verify_weights(const int channels, const std::vecto // TODO } +void nam::convnet::ConvNet::SetMaxBufferSize(const int maxBufferSize) +{ + nam::Buffer::SetMaxBufferSize(maxBufferSize); + + // Reset all ConvNetBlock instances with the new buffer size + for (auto& block : _blocks) + { + block.SetMaxBufferSize(maxBufferSize); + } +} + void nam::convnet::ConvNet::_update_buffers_(NAM_SAMPLE* input, const int num_frames) { this->Buffer::_update_buffers_(input, num_frames); const long buffer_size = (long)this->_input_buffer.size(); - if (this->_block_vals[0].rows() != 1 || this->_block_vals[0].cols() != buffer_size) + // Only need _block_vals[0] for the head + // Conv1D layers manage their own buffers now + if (this->_block_vals[0].rows() != this->_blocks.back().get_out_channels() + || this->_block_vals[0].cols() != buffer_size) { - this->_block_vals[0].resize(1, buffer_size); + this->_block_vals[0].resize(this->_blocks.back().get_out_channels(), buffer_size); this->_block_vals[0].setZero(); } - - for (size_t i = 1; i < this->_block_vals.size(); i++) - { - if (this->_block_vals[i].rows() == this->_blocks[i - 1].get_out_channels() - && this->_block_vals[i].cols() == buffer_size) - continue; // Already has correct size - this->_block_vals[i].resize(this->_blocks[i - 1].get_out_channels(), buffer_size); - 
this->_block_vals[i].setZero(); - } } void nam::convnet::ConvNet::_rewind_buffers_() { - // Need to rewind the block vals first because Buffer::rewind_buffers() - // resets the offset index - // The last _block_vals is the output of the last block and doesn't need to be - // rewound. - for (size_t k = 0; k < this->_block_vals.size() - 1; k++) - { - // We actually don't need to pull back a lot...just as far as the first - // input sample would grab from dilation - const long _dilation = this->_blocks[k].conv.get_dilation(); - for (long i = this->_receptive_field - _dilation, j = this->_input_buffer_offset - _dilation; - j < this->_input_buffer_offset; i++, j++) - for (long r = 0; r < this->_block_vals[k].rows(); r++) - this->_block_vals[k](r, i) = this->_block_vals[k](r, j); - } - // Now we can do the rest of the rewind + // Conv1D instances now manage their own ring buffers and handle rewinding internally + // So we don't need to rewind _block_vals for Conv1D layers + // We only need _block_vals for the head, and it doesn't need rewinding since it's only used + // for the current frame range + + // Just rewind the input buffer (for Buffer base class) this->Buffer::_rewind_buffers_(); } diff --git a/NAM/convnet.h b/NAM/convnet.h index 34cbaa0..2fab0f3 100644 --- a/NAM/convnet.h +++ b/NAM/convnet.h @@ -9,6 +9,9 @@ #include +#include "conv1d.h" +#include "dsp.h" + namespace nam { namespace convnet @@ -42,7 +45,13 @@ class ConvNetBlock ConvNetBlock() {}; void set_weights_(const int in_channels, const int out_channels, const int _dilation, const bool batchnorm, const std::string activation, std::vector::iterator& weights); - void process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long i_end) const; + void SetMaxBufferSize(const int maxBufferSize); + // Process input matrix directly (new API, similar to WaveNet) + void Process(const Eigen::MatrixXf& input, const int num_frames); + // Legacy method for compatibility (uses indices) + void process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long i_end); + // Get output from last Process() call + Eigen::Block GetOutput(const int num_frames); long get_out_channels() const; Conv1D conv; @@ -50,6 +59,7 @@ class ConvNetBlock BatchNorm batchnorm; bool _batchnorm = false; activations::Activation* activation = nullptr; + Eigen::MatrixXf _output; // Output buffer owned by the block }; class _Head @@ -72,6 +82,7 @@ class ConvNet : public Buffer ~ConvNet() = default; void process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) override; + void SetMaxBufferSize(const int maxBufferSize) override; protected: std::vector _blocks; diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index f3b5c14..8940314 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -204,68 +204,7 @@ std::unique_ptr nam::linear::Factory(const nlohmann::json& config, std // NN modules ================================================================= -void nam::Conv1D::set_weights_(std::vector::iterator& weights) -{ - if (this->_weight.size() > 0) - { - const long out_channels = this->_weight[0].rows(); - const long in_channels = this->_weight[0].cols(); - // Crazy ordering because that's how it gets flattened. 
- for (auto i = 0; i < out_channels; i++) - for (auto j = 0; j < in_channels; j++) - for (size_t k = 0; k < this->_weight.size(); k++) - this->_weight[k](i, j) = *(weights++); - } - for (long i = 0; i < this->_bias.size(); i++) - this->_bias(i) = *(weights++); -} - -void nam::Conv1D::set_size_(const int in_channels, const int out_channels, const int kernel_size, const bool do_bias, - const int _dilation) -{ - this->_weight.resize(kernel_size); - for (size_t i = 0; i < this->_weight.size(); i++) - this->_weight[i].resize(out_channels, - in_channels); // y = Ax, input array (C,L) - if (do_bias) - this->_bias.resize(out_channels); - else - this->_bias.resize(0); - this->_dilation = _dilation; -} - -void nam::Conv1D::set_size_and_weights_(const int in_channels, const int out_channels, const int kernel_size, - const int _dilation, const bool do_bias, std::vector::iterator& weights) -{ - this->set_size_(in_channels, out_channels, kernel_size, do_bias, _dilation); - this->set_weights_(weights); -} - -void nam::Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols, - const long j_start) const -{ - // This is the clever part ;) - for (size_t k = 0; k < this->_weight.size(); k++) - { - const long offset = this->_dilation * (k + 1 - this->_weight.size()); - if (k == 0) - output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols); - else - output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols); - } - if (this->_bias.size() > 0) - { - output.middleCols(j_start, ncols).colwise() += this->_bias; - } -} - -long nam::Conv1D::get_num_weights() const -{ - long num_weights = this->_bias.size(); - for (size_t i = 0; i < this->_weight.size(); i++) - num_weights += this->_weight[i].size(); - return num_weights; -} +// Conv1x1 ==================================================================== nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool _bias) { @@ -275,10 +214,6 @@ nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool this->_bias.resize(out_channels); } -Eigen::Block nam::Conv1x1::GetOutput(const int num_frames) -{ - return _output.block(0, 0, _output.rows(), num_frames); -} void nam::Conv1x1::SetMaxBufferSize(const int maxBufferSize) { diff --git a/NAM/dsp.h b/NAM/dsp.h index d22d5e8..3f9df92 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -173,40 +173,17 @@ std::unique_ptr Factory(const nlohmann::json& config, std::vector& w // NN modules ================================================================= -// TODO conv could take care of its own ring buffer. -class Conv1D -{ -public: - Conv1D() { this->_dilation = 1; }; - void set_weights_(std::vector::iterator& weights); - void set_size_(const int in_channels, const int out_channels, const int kernel_size, const bool do_bias, - const int _dilation); - void set_size_and_weights_(const int in_channels, const int out_channels, const int kernel_size, const int _dilation, - const bool do_bias, std::vector::iterator& weights); - // Process from input to output - // Rightmost indices of input go from i_start for ncols, - // Indices on output for from j_start (to j_start + ncols - i_start) - void process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols, - const long j_start) const; - long get_in_channels() const { return this->_weight.size() > 0 ? 
this->_weight[0].cols() : 0; }; - long get_kernel_size() const { return this->_weight.size(); }; - long get_num_weights() const; - long get_out_channels() const { return this->_weight.size() > 0 ? this->_weight[0].rows() : 0; }; - int get_dilation() const { return this->_dilation; }; - -protected: - // conv[kernel](cout, cin) - std::vector _weight; - Eigen::VectorXf _bias; - int _dilation; -}; - // Really just a linear layer class Conv1x1 { public: Conv1x1(const int in_channels, const int out_channels, const bool _bias); - Eigen::Block GetOutput(const int num_frames); + // Get the entire internal output buffer. This is intended for internal wiring + // between layers/arrays; callers should treat the buffer as pre-allocated + // storage and only consider the first `num_frames` columns valid for a given + // processing call. Slice with .leftCols(num_frames) as needed. + Eigen::MatrixXf& GetOutput() { return _output; } + const Eigen::MatrixXf& GetOutput() const { return _output; } void SetMaxBufferSize(const int maxBufferSize); void set_weights_(std::vector::iterator& weights); // :param input: (N,Cin) or (Cin,) diff --git a/NAM/ring_buffer.cpp b/NAM/ring_buffer.cpp new file mode 100644 index 0000000..64518b4 --- /dev/null +++ b/NAM/ring_buffer.cpp @@ -0,0 +1,108 @@ +#include "ring_buffer.h" +#include + +namespace nam +{ + +void RingBuffer::Reset(const int channels, const int max_buffer_size) +{ + // Store the max buffer size for external queries + _max_buffer_size = max_buffer_size; + + // Calculate storage size: 2 * max_lookback + max_buffer_size + // This ensures we have enough room for: + // - max_lookback at the start (for history after rewind) + // - max_buffer_size in the middle (for writes/reads) + // - no aliasing when rewinding + const long storage_size = 2 * _max_lookback + max_buffer_size; + _storage.resize(channels, storage_size); + _storage.setZero(); + // Initialize write position to max_lookback to leave room for history + // Zero the storage behind the starting write position (for lookback) + if (_max_lookback > 0) + { + _storage.leftCols(_max_lookback).setZero(); + } + _write_pos = _max_lookback; +} + +void RingBuffer::Write(const Eigen::MatrixXf& input, const int num_frames) +{ + // Assert that num_frames doesn't exceed max buffer size + assert(num_frames <= _max_buffer_size && "Write: num_frames must not exceed max_buffer_size"); + + // Check if we need to rewind + if (NeedsRewind(num_frames)) + Rewind(); + + // Write the input data at the write position + // NOTE: This function assumes that `input` is a full, pre-allocated MatrixXf + // covering the entire valid buffer range. Callers should not pass Block + // expressions across the API boundary; instead, pass the full buffer and + // slice inside the callee. This avoids Eigen evaluating Blocks into + // temporaries (which would allocate) when binding to MatrixXf. 
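To make the Write/Read/Advance contract described above concrete, here is a hedged sketch (not part of the patch) of the per-block cycle a Conv1D performs against this ring buffer; the channel count, receptive field, and frame counts are illustrative.

#include <Eigen/Dense>
#include "ring_buffer.h"

static void conv_like_cycle(nam::RingBuffer& ring)
{
  const int channels = 4;
  const int maxBufferSize = 512;
  const long receptive_field = 12; // e.g. (kernel_size - 1) * dilation

  ring.SetMaxLookback(receptive_field); // must precede Reset(): storage = 2 * lookback + maxBufferSize
  ring.Reset(channels, maxBufferSize);

  const int num_frames = 64;
  Eigen::MatrixXf input = Eigen::MatrixXf::Zero(channels, num_frames);

  ring.Write(input, num_frames);                     // rewinds internally if the end is near
  auto tap = ring.Read(num_frames, receptive_field); // view starting receptive_field frames back
  (void)tap;                                         // a Conv1D would accumulate weight[k] * tap here
  ring.Advance(num_frames);                          // commit the frames for the next call
}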
+ const int channels = _storage.rows(); + const int copy_cols = num_frames; + + for (int col = 0; col < copy_cols; ++col) + { + for (int row = 0; row < channels; ++row) + { + _storage(row, _write_pos + col) = input(row, col); + } + } +} + +Eigen::Block RingBuffer::Read(const int num_frames, const long lookback) +{ + // Assert that lookback doesn't exceed max_lookback + assert(lookback <= _max_lookback && "Read: lookback must not exceed max_lookback"); + + // Assert that num_frames doesn't exceed max buffer size + assert(num_frames <= _max_buffer_size && "Read: num_frames must not exceed max_buffer_size"); + + long read_pos = _write_pos - lookback; + + // Assert that read_pos is non-negative + // (Asserted by the access to _storage.block()) + return _storage.block(0, read_pos, _storage.rows(), num_frames); +} + +void RingBuffer::Advance(const int num_frames) +{ + _write_pos += num_frames; +} + +bool RingBuffer::NeedsRewind(const int num_frames) const +{ + return _write_pos + num_frames > (long)_storage.cols(); +} + +void RingBuffer::Rewind() +{ + if (_max_lookback == 0) + { + // No history to preserve, just reset + _write_pos = 0; + return; + } + + // Assert that write pointer is far enough along to avoid aliasing when copying + // We copy from position (_write_pos - _max_lookback) to position 0 + // For no aliasing, we need: _write_pos - _max_lookback >= _max_lookback + // Which simplifies to: _write_pos >= 2 * _max_lookback + assert(_write_pos >= 2 * _max_lookback + && "Write pointer must be at least 2 * max_lookback to avoid aliasing during rewind"); + + // Copy the max lookback amount of history back to the beginning + // This is the history that will be needed for lookback reads + const long copy_start = _write_pos - _max_lookback; + assert(copy_start >= 0 && copy_start < (long)_storage.cols() && "Copy start position must be within storage bounds"); + + // Copy _max_lookback samples from before the write position to the start + _storage.leftCols(_max_lookback) = _storage.middleCols(copy_start, _max_lookback); + + // Reset write position to just after the copied history + _write_pos = _max_lookback; +} +} // namespace nam diff --git a/NAM/ring_buffer.h b/NAM/ring_buffer.h new file mode 100644 index 0000000..f2c3dfe --- /dev/null +++ b/NAM/ring_buffer.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +namespace nam +{ +// Ring buffer for managing Eigen::MatrixXf buffers with write/read pointers +class RingBuffer +{ +public: + RingBuffer() {}; + // Initialize/resize storage + // :param channels: Number of channels (rows in the storage matrix) + // :param max_buffer_size: Maximum amount that will be written or read at once + void Reset(const int channels, const int max_buffer_size); + // Write new data at write pointer + // :param input: Input matrix (channels x num_frames) + // :param num_frames: Number of frames to write + // NOTE: This function expects a full, pre-allocated, column-major MatrixXf + // covering the entire valid buffer range. Callers should not pass + // Block expressions (e.g. .leftCols()) across the API boundary; instead, + // pass the full buffer and slice inside the callee. This avoids Eigen + // evaluating Blocks into temporaries (which would allocate) when + // binding to MatrixXf. 
+ void Write(const Eigen::MatrixXf& input, const int num_frames); + // Read data with optional lookback + // :param num_frames: Number of frames to read + // :param lookback: Number of frames to look back from write pointer (default 0) + // :return: Block reference to the storage data + Eigen::Block Read(const int num_frames, const long lookback = 0); + // Advance write pointer + // :param num_frames: Number of frames to advance + void Advance(const int num_frames); + // Get max buffer size (the value passed to Reset()) + int GetMaxBufferSize() const { return _max_buffer_size; } + // Get number of channels (rows) + int GetChannels() const { return _storage.rows(); } + // Set the max lookback (maximum history needed when rewinding) + void SetMaxLookback(const long max_lookback) { _max_lookback = max_lookback; } + +private: + // Wrap buffer when approaching end (called automatically if needed) + void Rewind(); + // Check if rewind is needed before `num_frames` are written or read + // :param num_frames: Number of frames that will be written + // :return: true if rewind is needed + bool NeedsRewind(const int num_frames) const; + // Get current write position + long GetWritePos() const { return _write_pos; } + + Eigen::MatrixXf _storage; // channels x storage_size + long _write_pos = 0; // Current write position + long _max_lookback = 0; // Maximum lookback needed when rewinding + int _max_buffer_size = 0; // Maximum buffer size passed to Reset() +}; +} // namespace nam diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 955d3d1..eca40f3 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -7,19 +7,18 @@ #include "registry.h" #include "wavenet.h" -nam::wavenet::_DilatedConv::_DilatedConv(const int in_channels, const int out_channels, const int kernel_size, - const int bias, const int dilation) -{ - this->set_size_(in_channels, out_channels, kernel_size, bias, dilation); -} - // Layer ====================================================================== void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize) { + _conv.SetMaxBufferSize(maxBufferSize); _input_mixin.SetMaxBufferSize(maxBufferSize); _z.resize(this->_conv.get_out_channels(), maxBufferSize); _1x1.SetMaxBufferSize(maxBufferSize); + // Pre-allocate output buffers + const long channels = this->get_channels(); + this->_output_next_layer.resize(channels, maxBufferSize); + this->_output_head.resize(channels, maxBufferSize); } void nam::wavenet::_Layer::set_weights_(std::vector::iterator& weights) @@ -29,21 +28,21 @@ void nam::wavenet::_Layer::set_weights_(std::vector::iterator& weights) this->_1x1.set_weights_(weights); } -void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, - Eigen::MatrixXf& head_input, Eigen::MatrixXf& output, const long i_start, - const long j_start, const int num_frames) +void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames) { - const long ncols = (long)num_frames; // TODO clean this up const long channels = this->get_channels(); - // Input dilated conv - this->_conv.process_(input, this->_z, i_start, ncols, 0); - // Mix-in condition - _input_mixin.process_(condition, num_frames); - this->_z.leftCols(num_frames).noalias() += _input_mixin.GetOutput(num_frames); + // Step 1: input convolutions + this->_conv.Process(input, num_frames); + this->_input_mixin.process_(condition, num_frames); + this->_z.leftCols(num_frames).noalias() = + this->_conv.GetOutput().leftCols(num_frames) + 
_input_mixin.GetOutput().leftCols(num_frames); + + // Step 2 & 3: activation and 1x1 if (!this->_gated) { this->_activation->apply(this->_z.leftCols(num_frames)); + _1x1.process_(_z, num_frames); } else { @@ -56,35 +55,22 @@ void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::M activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, i, channels, 1)); } this->_z.block(0, 0, channels, num_frames).array() *= this->_z.block(channels, 0, channels, num_frames).array(); + _1x1.process_(_z.topRows(channels), num_frames); // Might not be RT safe } - head_input.leftCols(num_frames).noalias() += this->_z.block(0, 0, channels, num_frames); - if (!_gated) - { - _1x1.process_(_z, num_frames); - } + // Store output to head (skip connection: activated conv output) + if (!this->_gated) + this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames); else - { - // Probably not RT-safe yet - _1x1.process_(_z.topRows(channels), num_frames); - } - output.middleCols(j_start, ncols).noalias() = input.middleCols(i_start, ncols) + _1x1.GetOutput(num_frames); + this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(channels).leftCols(num_frames); + // Store output to next layer (residual connection: input + _1x1 output) + this->_output_next_layer.leftCols(num_frames).noalias() = + input.leftCols(num_frames) + _1x1.GetOutput().leftCols(num_frames); } -void nam::wavenet::_Layer::set_num_frames_(const long num_frames) -{ - // TODO deprecate for SetMaxBufferSize() - if (this->_z.rows() == this->_conv.get_out_channels() && this->_z.cols() == num_frames) - return; // Already has correct size - - this->_z.resize(this->_conv.get_out_channels(), num_frames); - this->_z.setZero(); -} // LayerArray ================================================================= -#define LAYER_ARRAY_BUFFER_SIZE 65536 - nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition_size, const int head_size, const int channels, const int kernel_size, const std::vector& dilations, const std::string activation, const bool gated, const bool head_bias) @@ -93,13 +79,6 @@ nam::wavenet::_LayerArray::_LayerArray(const int input_size, const int condition { for (size_t i = 0; i < dilations.size(); i++) this->_layers.push_back(_Layer(condition_size, channels, kernel_size, dilations[i], activation, gated)); - const long receptive_field = this->_get_receptive_field(); - for (size_t i = 0; i < dilations.size(); i++) - { - this->_layer_buffers.push_back(Eigen::MatrixXf(channels, LAYER_ARRAY_BUFFER_SIZE + receptive_field - 1)); - this->_layer_buffers[i].setZero(); - } - this->_buffer_start = this->_get_receptive_field() - 1; } void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize) @@ -110,12 +89,12 @@ void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize) { it->SetMaxBufferSize(maxBufferSize); } + // Pre-allocate output buffers + const long channels = this->_get_channels(); + this->_layer_outputs.resize(channels, maxBufferSize); + this->_head_inputs.resize(channels, maxBufferSize); } -void nam::wavenet::_LayerArray::advance_buffers_(const int num_frames) -{ - this->_buffer_start += num_frames; -} long nam::wavenet::_LayerArray::get_receptive_field() const { @@ -125,150 +104,83 @@ long nam::wavenet::_LayerArray::get_receptive_field() const return result; } -void nam::wavenet::_LayerArray::prepare_for_frames_(const long num_frames) -{ - // Example: - // _buffer_start = 0 - // num_frames = 64 - // buffer_size = 64 - // 
-> this will write on indices 0 through 63, inclusive. - // -> No illegal writes. - // -> no rewind needed. - if (this->_buffer_start + num_frames > this->_get_buffer_size()) - this->_rewind_buffers_(); -} - -void nam::wavenet::_LayerArray::process_(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, - Eigen::MatrixXf& head_inputs, Eigen::MatrixXf& layer_outputs, - Eigen::MatrixXf& head_outputs, const int num_frames) -{ - this->_rechannel.process_(layer_inputs, num_frames); - this->_layer_buffers[0].middleCols(this->_buffer_start, num_frames) = _rechannel.GetOutput(num_frames); - const size_t last_layer = this->_layers.size() - 1; - for (size_t i = 0; i < this->_layers.size(); i++) - { - this->_layers[i].process_(this->_layer_buffers[i], condition, head_inputs, - i == last_layer ? layer_outputs : this->_layer_buffers[i + 1], this->_buffer_start, - i == last_layer ? 0 : this->_buffer_start, num_frames); - } - _head_rechannel.process_(head_inputs, num_frames); - head_outputs.leftCols(num_frames) = _head_rechannel.GetOutput(num_frames); -} -void nam::wavenet::_LayerArray::set_num_frames_(const long num_frames) +void nam::wavenet::_LayerArray::Process(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, + const int num_frames) { - // Wavenet checks for unchanged num_frames; if we made it here, there's - // something to do. - if (LAYER_ARRAY_BUFFER_SIZE - num_frames < this->_get_receptive_field()) - { - std::stringstream ss; - ss << "Asked to accept a buffer of " << num_frames << " samples, but the buffer is too short (" - << LAYER_ARRAY_BUFFER_SIZE << ") to get out of the recptive field (" << this->_get_receptive_field() - << "); copy errors could occur!\n"; - throw std::runtime_error(ss.str().c_str()); - } - for (size_t i = 0; i < this->_layers.size(); i++) - this->_layers[i].set_num_frames_(num_frames); + // Zero head inputs accumulator (first layer array) + this->_head_inputs.setZero(); + ProcessInner(layer_inputs, condition, num_frames); } -void nam::wavenet::_LayerArray::set_weights_(std::vector::iterator& weights) +void nam::wavenet::_LayerArray::Process(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, + const Eigen::MatrixXf& head_inputs, const int num_frames) { - this->_rechannel.set_weights_(weights); - for (size_t i = 0; i < this->_layers.size(); i++) - this->_layers[i].set_weights_(weights); - this->_head_rechannel.set_weights_(weights); + // Copy head inputs from previous layer array + this->_head_inputs.leftCols(num_frames).noalias() = head_inputs.leftCols(num_frames); + ProcessInner(layer_inputs, condition, num_frames); } -long nam::wavenet::_LayerArray::_get_channels() const +void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, + const int num_frames) { - return this->_layers.size() > 0 ? this->_layers[0].get_channels() : 0; -} + // Process rechannel and get output + this->_rechannel.process_(layer_inputs, num_frames); + Eigen::MatrixXf& rechannel_output = _rechannel.GetOutput(); -long nam::wavenet::_LayerArray::_get_receptive_field() const -{ - // TODO remove this and use get_receptive_field() instead! - long res = 1; + // Process layers for (size_t i = 0; i < this->_layers.size(); i++) - res += (this->_layers[i].get_kernel_size() - 1) * this->_layers[i].get_dilation(); - return res; -} - -void nam::wavenet::_LayerArray::_rewind_buffers_() -// Consider wrapping instead... -// Can make this smaller--largest dilation, not receptive field! 
-{ - const long start = this->_get_receptive_field() - 1; - for (size_t i = 0; i < this->_layer_buffers.size(); i++) { - const long d = (this->_layers[i].get_kernel_size() - 1) * this->_layers[i].get_dilation(); - this->_layer_buffers[i].middleCols(start - d, d) = this->_layer_buffers[i].middleCols(this->_buffer_start - d, d); + // Process first layer with rechannel output, subsequent layers with previous layer output + // Use separate branches to avoid ternary operator creating temporaries + if (i == 0) + { + // First layer consumes the rechannel output buffer + this->_layers[i].Process(rechannel_output, condition, num_frames); + } + else + { + // Subsequent layers consume the full output buffer of the previous layer + Eigen::MatrixXf& prev_output = this->_layers[i - 1].GetOutputNextLayer(); + this->_layers[i].Process(prev_output, condition, num_frames); + } + + // Accumulate head output from this layer + this->_head_inputs.leftCols(num_frames).noalias() += this->_layers[i].GetOutputHead().leftCols(num_frames); } - this->_buffer_start = start; -} -// Head ======================================================================= + // Store output from last layer + const size_t last_layer = this->_layers.size() - 1; + this->_layer_outputs.leftCols(num_frames).noalias() = + this->_layers[last_layer].GetOutputNextLayer().leftCols(num_frames); -nam::wavenet::_Head::_Head(const int input_size, const int num_layers, const int channels, const std::string activation) -: _channels(channels) -, _head(num_layers > 0 ? channels : input_size, 1, true) -, _activation(activations::Activation::get_activation(activation)) -{ - assert(num_layers > 0); - int dx = input_size; - for (int i = 0; i < num_layers; i++) - { - this->_layers.push_back(Conv1x1(dx, i == num_layers - 1 ? 
1 : channels, true)); - dx = channels; - if (i < num_layers - 1) - this->_buffers.push_back(Eigen::MatrixXf()); - } + // Process head rechannel + _head_rechannel.process_(this->_head_inputs, num_frames); } -void nam::wavenet::_Head::Reset(const double sampleRate, const int maxBufferSize) -{ - set_num_frames_((long)maxBufferSize); -} -void nam::wavenet::_Head::set_weights_(std::vector::iterator& weights) +Eigen::MatrixXf& nam::wavenet::_LayerArray::GetHeadOutputs() { - for (size_t i = 0; i < this->_layers.size(); i++) - this->_layers[i].set_weights_(weights); + return this->_head_rechannel.GetOutput(); } -void nam::wavenet::_Head::process_(Eigen::MatrixXf& inputs, Eigen::MatrixXf& outputs) +const Eigen::MatrixXf& nam::wavenet::_LayerArray::GetHeadOutputs() const { - const size_t num_layers = this->_layers.size(); - this->_apply_activation_(inputs); - if (num_layers == 1) - outputs = this->_layers[0].process(inputs); - else - { - this->_buffers[0] = this->_layers[0].process(inputs); - for (size_t i = 1; i < num_layers; i++) - { // Asserted > 0 layers - this->_apply_activation_(this->_buffers[i - 1]); - if (i < num_layers - 1) - this->_buffers[i] = this->_layers[i].process(this->_buffers[i - 1]); - else - outputs = this->_layers[i].process(this->_buffers[i - 1]); - } - } + return this->_head_rechannel.GetOutput(); } -void nam::wavenet::_Head::set_num_frames_(const long num_frames) + +void nam::wavenet::_LayerArray::set_weights_(std::vector::iterator& weights) { - for (size_t i = 0; i < this->_buffers.size(); i++) - { - if (this->_buffers[i].rows() == this->_channels && this->_buffers[i].cols() == num_frames) - continue; // Already has correct size - this->_buffers[i].resize(this->_channels, num_frames); - this->_buffers[i].setZero(); // Shouldn't be needed--these are written to before they're used. - } + this->_rechannel.set_weights_(weights); + for (size_t i = 0; i < this->_layers.size(); i++) + this->_layers[i].set_weights_(weights); + this->_head_rechannel.set_weights_(weights); } -void nam::wavenet::_Head::_apply_activation_(Eigen::MatrixXf& x) +long nam::wavenet::_LayerArray::_get_channels() const { - this->_activation->apply(x); + return this->_layers.size() > 0 ? this->_layers[0].get_channels() : 0; } // WaveNet ==================================================================== @@ -287,9 +199,6 @@ nam::wavenet::WaveNet::WaveNet(const std::vector layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size, layer_array_params[i].channels, layer_array_params[i].kernel_size, layer_array_params[i].dilations, layer_array_params[i].activation, layer_array_params[i].gated, layer_array_params[i].head_bias)); - this->_layer_array_outputs.push_back(Eigen::MatrixXf(layer_array_params[i].channels, 0)); - if (i == 0) - this->_head_arrays.push_back(Eigen::MatrixXf(layer_array_params[i].channels, 0)); if (i > 0) if (layer_array_params[i].channels != layer_array_params[i - 1].head_size) { @@ -298,9 +207,7 @@ nam::wavenet::WaveNet::WaveNet(const std::vector << ") doesn't match head_size of preceding layer (" << layer_array_params[i - 1].head_size << "!\n"; throw std::runtime_error(ss.str().c_str()); } - this->_head_arrays.push_back(Eigen::MatrixXf(layer_array_params[i].head_size, 0)); } - this->_head_output.resize(1, 0); // Mono output! 
this->set_weights_(weights); mPrewarmSamples = 1; @@ -313,7 +220,6 @@ void nam::wavenet::WaveNet::set_weights_(std::vector& weights) std::vector::iterator it = weights.begin(); for (size_t i = 0; i < this->_layer_arrays.size(); i++) this->_layer_arrays[i].set_weights_(it); - // this->_head.set_params_(it); this->_head_scale = *(it++); if (it != weights.end()) { @@ -332,30 +238,9 @@ void nam::wavenet::WaveNet::set_weights_(std::vector& weights) void nam::wavenet::WaveNet::SetMaxBufferSize(const int maxBufferSize) { DSP::SetMaxBufferSize(maxBufferSize); - this->_condition.resize(this->_get_condition_dim(), maxBufferSize); - for (size_t i = 0; i < this->_head_arrays.size(); i++) - this->_head_arrays[i].resize(this->_head_arrays[i].rows(), maxBufferSize); - for (size_t i = 0; i < this->_layer_array_outputs.size(); i++) - this->_layer_array_outputs[i].resize(this->_layer_array_outputs[i].rows(), maxBufferSize); - this->_head_output.resize(this->_head_output.rows(), maxBufferSize); - this->_head_output.setZero(); - for (size_t i = 0; i < this->_layer_arrays.size(); i++) this->_layer_arrays[i].SetMaxBufferSize(maxBufferSize); - // this->_head.SetMaxBufferSize(maxBufferSize); -} - -void nam::wavenet::WaveNet::_advance_buffers_(const int num_frames) -{ - for (size_t i = 0; i < this->_layer_arrays.size(); i++) - this->_layer_arrays[i].advance_buffers_(num_frames); -} - -void nam::wavenet::WaveNet::_prepare_for_frames_(const long num_frames) -{ - for (size_t i = 0; i < this->_layer_arrays.size(); i++) - this->_layer_arrays[i].prepare_for_frames_(num_frames); } void nam::wavenet::WaveNet::_set_condition_array(NAM_SAMPLE* input, const int num_frames) @@ -369,35 +254,37 @@ void nam::wavenet::WaveNet::_set_condition_array(NAM_SAMPLE* input, const int nu void nam::wavenet::WaveNet::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_frames) { assert(num_frames <= mMaxBufferSize); - this->_prepare_for_frames_(num_frames); this->_set_condition_array(input, num_frames); // Main layer arrays: // Layer-to-layer - // Sum on head output - this->_head_arrays[0].setZero(); for (size_t i = 0; i < this->_layer_arrays.size(); i++) - this->_layer_arrays[i].process_(i == 0 ? this->_condition : this->_layer_array_outputs[i - 1], this->_condition, - this->_head_arrays[i], this->_layer_array_outputs[i], this->_head_arrays[i + 1], - num_frames); - // this->_head.process_( - // this->_head_input, - // this->_head_output - //); - // Copy to required output array - // Hack: apply head scale here; revisit when/if I activate the head. - // assert(this->_head_output.rows() == 1); - - const long final_head_array = this->_head_arrays.size() - 1; - assert(this->_head_arrays[final_head_array].rows() == 1); + { + if (i == 0) + { + // First layer array - no head input + this->_layer_arrays[i].Process(this->_condition, this->_condition, num_frames); + } + else + { + // Subsequent layer arrays - use outputs from previous layer array. + // Pass full buffers and slice inside the callee to avoid passing Blocks + // across API boundaries (which can cause Eigen to allocate temporaries). 
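The "avoid passing Blocks across API boundaries" comment below is worth unpacking once; this is a small illustrative sketch (not part of the patch, with made-up function names) of the allocation that binding an Eigen Block expression to a const Eigen::MatrixXf& parameter triggers on the audio thread.

#include <Eigen/Dense>

// Stand-in for any API in this codebase that takes a full, pre-allocated matrix by const reference.
static void consume(const Eigen::MatrixXf& x) { (void)x; }

static void caller(Eigen::MatrixXf& buffer, int num_frames)
{
  // Compiles, but constructs a temporary MatrixXf from the Block expression -> heap allocation:
  consume(buffer.leftCols(num_frames));
  // What this patch does instead: pass the whole buffer and let the callee slice with leftCols().
  consume(buffer);
}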
+ Eigen::MatrixXf& prev_layer_outputs = this->_layer_arrays[i - 1].GetLayerOutputs(); + Eigen::MatrixXf& prev_head_outputs = this->_layer_arrays[i - 1].GetHeadOutputs(); + this->_layer_arrays[i].Process(prev_layer_outputs, this->_condition, prev_head_outputs, num_frames); + } + } + + // (Head not implemented) + + auto& final_head_outputs = this->_layer_arrays.back().GetHeadOutputs(); + assert(final_head_outputs.rows() == 1); for (int s = 0; s < num_frames; s++) { - const float out = this->_head_scale * this->_head_arrays[final_head_array](0, s); + const float out = this->_head_scale * final_head_outputs(0, s); output[s] = out; } - - // Finalize to prepare for the next call: - this->_advance_buffers_(num_frames); } // Factory to instantiate from nlohmann json diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 60550d9..63881ca 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -7,19 +7,12 @@ #include #include "dsp.h" +#include "conv1d.h" namespace nam { namespace wavenet { -// Rework the initialization API slightly. Merge w/ dsp.h later. -class _DilatedConv : public Conv1D -{ -public: - _DilatedConv(const int in_channels, const int out_channels, const int kernel_size, const int bias, - const int dilation); -}; - class _Layer { public: @@ -37,15 +30,9 @@ class _Layer // Process a block of frames. // :param `input`: from previous layer // :param `condition`: conditioning input (input to the WaveNet / "skip-in") - // :param `head_input`: input to the head ("skip-out") - // :param `output`: to next layer - // :param `i_start`: Index of the first column of the input samples that the conv layer's first kernel will process - // :param `j_start`: Index of the first column of the output block that will be written to // :param `num_frames`: number of frames to process - void process_(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, Eigen::MatrixXf& head_input, - Eigen::MatrixXf& output, const long i_start, const long j_start, const int num_frames); - // DEPRECATED - use SetMaxBufferSize() instead - void set_num_frames_(const long num_frames); + // Outputs are stored internally and accessible via GetOutputNextLayer() and GetOutputHead() + void Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames); // The number of channels expected as input/output from this layer long get_channels() const { return this->_conv.get_in_channels(); }; // Dilation of the input convolution layer @@ -53,15 +40,34 @@ class _Layer // Kernel size of the input convolution layer long get_kernel_size() const { return this->_conv.get_kernel_size(); }; + // Get output to next layer (residual connection: input + _1x1 output) + // Returns the full pre-allocated buffer; only the first `num_frames` columns + // are valid for a given processing call. Slice with .leftCols(num_frames) as needed. + Eigen::MatrixXf& GetOutputNextLayer() { return this->_output_next_layer; } + const Eigen::MatrixXf& GetOutputNextLayer() const { return this->_output_next_layer; } + // Get output to head (skip connection: activated conv output) + // Returns the full pre-allocated buffer; only the first `num_frames` columns + // are valid for a given processing call. Slice with .leftCols(num_frames) as needed. 
+ Eigen::MatrixXf& GetOutputHead() { return this->_output_head; } + const Eigen::MatrixXf& GetOutputHead() const { return this->_output_head; } + + // Access Conv1D for Reset() propagation (needed for _LayerArray) + Conv1D& get_conv() { return _conv; } + const Conv1D& get_conv() const { return _conv; } + private: // The dilated convolution at the front of the block - _DilatedConv _conv; + Conv1D _conv; // Input mixin Conv1x1 _input_mixin; // The post-activation 1x1 convolution Conv1x1 _1x1; // The internal state Eigen::MatrixXf _z; + // Output to next layer (residual connection: input + _1x1 output) + Eigen::MatrixXf _output_next_layer; + // Output to head (skip connection: activated conv output) + Eigen::MatrixXf _output_head; activations::Activation* _activation; const bool _gated; @@ -106,22 +112,26 @@ class _LayerArray void SetMaxBufferSize(const int maxBufferSize); - void advance_buffers_(const int num_frames); - - // Preparing for frames: - // Rewind buffers if needed - // Shift index to prepare - // - void prepare_for_frames_(const long num_frames); - // All arrays are "short". - void process_(const Eigen::MatrixXf& layer_inputs, // Short - const Eigen::MatrixXf& condition, // Short - Eigen::MatrixXf& layer_outputs, // Short - Eigen::MatrixXf& head_inputs, // Sum up on this. - Eigen::MatrixXf& head_outputs, // post head-rechannel - const int num_frames); - void set_num_frames_(const long num_frames); + // Process without head input (first layer array) - zeros head inputs before proceeding + void Process(const Eigen::MatrixXf& layer_inputs, // Short + const Eigen::MatrixXf& condition, // Short + const int num_frames); + // Process with head input (subsequent layer arrays) - copies head input before proceeding + void Process(const Eigen::MatrixXf& layer_inputs, // Short + const Eigen::MatrixXf& condition, // Short + const Eigen::MatrixXf& head_inputs, // Short - from previous layer array + const int num_frames); + // Get output from last layer (for next layer array) + // Returns the full pre-allocated buffer; only the first `num_frames` columns + // are valid for a given processing call. Slice with .leftCols(num_frames) as needed. + Eigen::MatrixXf& GetLayerOutputs() { return this->_layer_outputs; } + const Eigen::MatrixXf& GetLayerOutputs() const { return this->_layer_outputs; } + // Get head outputs (post head-rechannel) + // Returns the full pre-allocated buffer; only the first `num_frames` columns + // are valid for a given processing call. Slice with .leftCols(num_frames) as needed. + Eigen::MatrixXf& GetHeadOutputs(); + const Eigen::MatrixXf& GetHeadOutputs() const; void set_weights_(std::vector::iterator& it); // "Zero-indexed" receptive field. @@ -129,54 +139,22 @@ class _LayerArray long get_receptive_field() const; private: - long _buffer_start; // The rechannel before the layers Conv1x1 _rechannel; - // Buffers in between layers. - // buffer [i] is the input to layer [i]. - // the last layer outputs to a short array provided by outside. - std::vector _layer_buffers; // The layer objects std::vector<_Layer> _layers; + // Output from last layer (for next layer array) + Eigen::MatrixXf _layer_outputs; + // Accumulated head inputs from all layers + Eigen::MatrixXf _head_inputs; // Rechannel for the head Conv1x1 _head_rechannel; - long _get_buffer_size() const { return this->_layer_buffers.size() > 0 ? this->_layer_buffers[0].cols() : 0; }; long _get_channels() const; - // "One-indexed" receptive field - // TODO remove! - // E.g. a 1x1 convolution has a o.i.r.f. of one. 
- long _get_receptive_field() const; - void _rewind_buffers_(); -}; - -// The head module -// [Act->Conv] x L -class _Head -{ -public: - _Head(const int input_size, const int num_layers, const int channels, const std::string activation); - void Reset(const double sampleRate, const int maxBufferSize); - void set_weights_(std::vector::iterator& weights); - // NOTE: the head transforms the provided input by applying a nonlinearity - // to it in-place! - void process_(Eigen::MatrixXf& inputs, Eigen::MatrixXf& outputs); - void set_num_frames_(const long num_frames); - -private: - int _channels; - std::vector _layers; - Conv1x1 _head; - activations::Activation* _activation; - - // Stores the outputs of the convs *except* the last one, which goes in - // The array `outputs` provided to .process_() - std::vector _buffers; - - // Apply the activation to the provided array, in-place - void _apply_activation_(Eigen::MatrixXf& x); + // Common processing logic after head inputs are set + void ProcessInner(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, const int num_frames); }; // The main WaveNet model @@ -202,20 +180,8 @@ class WaveNet : public DSP private: std::vector<_LayerArray> _layer_arrays; - // Their outputs - std::vector _layer_array_outputs; - // _Head _head; - // One more than total layer arrays - std::vector _head_arrays; float _head_scale; - Eigen::MatrixXf _head_output; - - void _advance_buffers_(const int num_frames); - void _prepare_for_frames_(const long num_frames); - - // Ensure that all buffer arrays are the right size for this num_frames - void _set_num_frames_(const long num_frames); int mPrewarmSamples = 0; // Pre-compute during initialization int PrewarmSamples() override { return mPrewarmSamples; }; diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 1e09436..fcc84d1 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -13,6 +13,8 @@ include_directories(tools ${NAM_DEPS_PATH}/nlohmann) add_executable(loadmodel loadmodel.cpp ${NAM_SOURCES}) add_executable(benchmodel benchmodel.cpp ${NAM_SOURCES}) add_executable(run_tests run_tests.cpp ${NAM_SOURCES}) +# Compile run_tests without optimizations to ensure allocation tracking works correctly +set_target_properties(run_tests PROPERTIES COMPILE_OPTIONS "-O0") source_group(NAM ${CMAKE_CURRENT_SOURCE_DIR} FILES ${NAM_SOURCES}) @@ -46,3 +48,4 @@ endif() # /Users/steve/src/NeuralAmpModelerCore/Dependencies/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h # Don't let this break my build on debug: set_source_files_properties(../NAM/dsp.cpp PROPERTIES COMPILE_FLAGS "-Wno-error") +set_source_files_properties(../NAM/conv1d.cpp PROPERTIES COMPILE_FLAGS "-Wno-error") \ No newline at end of file diff --git a/tools/benchmark_compare.sh b/tools/benchmark_compare.sh index 9ec2476..e742fd1 100755 --- a/tools/benchmark_compare.sh +++ b/tools/benchmark_compare.sh @@ -1,7 +1,7 @@ #!/bin/bash -# Script to compare performance of current branch against main -# Usage: ./tools/benchmark_compare.sh [--model MODEL_PATH] +# Script to compare performance of current branch against another branch (default: main) +# Usage: ./tools/benchmark_compare.sh [--model MODEL_PATH] [--branch BRANCH_NAME] set -e # Exit on error @@ -9,6 +9,7 @@ MODEL_PATH="example_models/wavenet_a1_standard.nam" BUILD_DIR="build" BENCHMARK_EXEC="build/tools/benchmodel" NUM_RUNS=10 +COMPARE_BRANCH="main" # Default branch to compare against # Report file will be set with timestamp in main() # Colors for output @@ -129,22 +130,23 @@ 
calculate_stats() { # Function to generate report generate_report() { - local main_results="$1" + local compare_results="$1" local current_results="$2" local current_branch="$3" - local main_commit="$4" - local current_commit="$5" - local report_file="$6" + local compare_branch="$4" + local compare_commit="$5" + local current_commit="$6" + local report_file="$7" echo "Generating performance comparison report..." # Calculate statistics for both branches - read main_mean main_median main_min main_max main_stddev <<< $(calculate_stats "$main_results") + read compare_mean compare_median compare_min compare_max compare_stddev <<< $(calculate_stats "$compare_results") read current_mean current_median current_min current_max current_stddev <<< $(calculate_stats "$current_results") # Calculate percentage difference - diff_mean=$(echo "scale=2; (($current_mean - $main_mean) / $main_mean) * 100" | bc) - diff_median=$(echo "scale=2; (($current_median - $main_median) / $main_median) * 100" | bc) + diff_mean=$(echo "scale=2; (($current_mean - $compare_mean) / $compare_mean) * 100" | bc) + diff_median=$(echo "scale=2; (($current_median - $compare_median) / $compare_median) * 100" | bc) # Generate report { @@ -157,14 +159,14 @@ generate_report() { echo "Date: $(date)" echo "" echo "----------------------------------------" - echo "Branch: main" + echo "Branch: $compare_branch" echo "----------------------------------------" - echo "Commit: ${main_commit}" - echo "Mean: ${main_mean} ms" - echo "Median: ${main_median} ms" - echo "Min: ${main_min} ms" - echo "Max: ${main_max} ms" - echo "Std Dev: ${main_stddev} ms" + echo "Commit: ${compare_commit}" + echo "Mean: ${compare_mean} ms" + echo "Median: ${compare_median} ms" + echo "Min: ${compare_min} ms" + echo "Max: ${compare_max} ms" + echo "Std Dev: ${compare_stddev} ms" echo "" echo "----------------------------------------" echo "Branch: $current_branch" @@ -180,18 +182,18 @@ generate_report() { echo "Comparison" echo "----------------------------------------" if (( $(echo "$diff_mean > 0" | bc -l) )); then - echo "Mean: ${current_branch} is ${diff_mean}% SLOWER than main" + echo "Mean: ${current_branch} is ${diff_mean}% SLOWER than ${compare_branch}" else - echo "Mean: ${current_branch} is ${diff_mean#-}% FASTER than main" + echo "Mean: ${current_branch} is ${diff_mean#-}% FASTER than ${compare_branch}" fi if (( $(echo "$diff_median > 0" | bc -l) )); then - echo "Median: ${current_branch} is ${diff_median}% SLOWER than main" + echo "Median: ${current_branch} is ${diff_median}% SLOWER than ${compare_branch}" else - echo "Median: ${current_branch} is ${diff_median#-}% FASTER than main" + echo "Median: ${current_branch} is ${diff_median#-}% FASTER than ${compare_branch}" fi echo "" - echo "Raw Results (main):" - cat "$main_results" | awk '{printf " %.3f ms\n", $1}' + echo "Raw Results ($compare_branch):" + cat "$compare_results" | awk '{printf " %.3f ms\n", $1}' echo "" echo "Raw Results ($current_branch):" cat "$current_results" | awk '{printf " %.3f ms\n", $1}' @@ -216,11 +218,21 @@ main() { MODEL_PATH="$2" shift 2 ;; + --branch) + if [ -z "$2" ]; then + echo -e "${RED}Error: --branch requires a branch name argument${NC}" + echo "Use --help for usage information" + exit 1 + fi + COMPARE_BRANCH="$2" + shift 2 + ;; --help|-h) - echo "Usage: $0 [--model MODEL_PATH]" + echo "Usage: $0 [--model MODEL_PATH] [--branch BRANCH_NAME]" echo "" echo "Options:" echo " --model MODEL_PATH Path to the model file to benchmark (default: 
example_models/wavenet_a1_standard.nam)" + echo " --branch BRANCH_NAME Branch to compare against (default: main)" echo " --help, -h Show this help message" exit 0 ;; @@ -246,13 +258,13 @@ main() { # Get current branch current_branch=$(git rev-parse --abbrev-ref HEAD) - if [ "$current_branch" = "main" ]; then - echo -e "${RED}Error: Already on main branch. Please checkout a different branch first.${NC}" + if [ "$current_branch" = "$COMPARE_BRANCH" ]; then + echo -e "${RED}Error: Already on $COMPARE_BRANCH branch. Please checkout a different branch first.${NC}" exit 1 fi echo -e "${YELLOW}Current branch: ${current_branch}${NC}" - echo -e "${YELLOW}Comparing against: main${NC}" + echo -e "${YELLOW}Comparing against: ${COMPARE_BRANCH}${NC}" echo "" # Generate timestamped report filename @@ -260,11 +272,11 @@ main() { REPORT_FILE="benchmark_report_${TIMESTAMP}.txt" # Create temporary files for results - main_results=$(mktemp) + compare_results=$(mktemp) current_results=$(mktemp) # Variables to store commit hashes - main_commit="" + compare_commit="" current_commit="" # Save untracked model file if it exists (to preserve it across branch switches) @@ -280,7 +292,7 @@ main() { # Cleanup function cleanup() { - rm -f "$main_results" "$current_results" + rm -f "$compare_results" "$current_results" # Restore original branch if we're not on it if [ -n "$current_branch" ] && [ "$(git rev-parse --abbrev-ref HEAD)" != "$current_branch" ]; then git checkout "$current_branch" > /dev/null 2>&1 || true @@ -299,23 +311,23 @@ main() { } trap cleanup EXIT - # Test main branch - echo -e "${YELLOW}=== Testing main branch ===${NC}" + # Test comparison branch + echo -e "${YELLOW}=== Testing ${COMPARE_BRANCH} branch ===${NC}" # Stash any uncommitted changes (only if there are any) if ! git diff-index --quiet HEAD -- 2>/dev/null || ! 
git diff-index --quiet --cached HEAD -- 2>/dev/null; then git stash push -m "benchmark_compare.sh temporary stash" > /dev/null 2>&1 stashed=true fi - # Restore model file to main branch if we backed it up (so it's available for benchmarking) + # Restore model file to comparison branch if we backed it up (so it's available for benchmarking) if [ -n "$model_backup" ] && [ -f "$model_backup" ]; then mkdir -p "$(dirname "$MODEL_PATH")" cp "$model_backup" "$MODEL_PATH" fi # Use --force to allow overwriting untracked files if needed - git checkout main --force 2>/dev/null || git checkout main - main_commit=$(git rev-parse HEAD) - echo "Commit: ${main_commit}" - run_benchmark "main" "$main_results" + git checkout "$COMPARE_BRANCH" --force 2>/dev/null || git checkout "$COMPARE_BRANCH" + compare_commit=$(git rev-parse HEAD) + echo "Commit: ${compare_commit}" + run_benchmark "$COMPARE_BRANCH" "$compare_results" # Test current branch echo -e "${YELLOW}=== Testing ${current_branch} branch ===${NC}" @@ -334,7 +346,7 @@ main() { run_benchmark "$current_branch" "$current_results" # Generate report - generate_report "$main_results" "$current_results" "$current_branch" "$main_commit" "$current_commit" "$REPORT_FILE" + generate_report "$compare_results" "$current_results" "$current_branch" "$COMPARE_BRANCH" "$compare_commit" "$current_commit" "$REPORT_FILE" echo -e "${GREEN}Benchmark comparison complete!${NC}" } diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 2aa66ec..96960ba 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -3,10 +3,16 @@ #include #include "test/test_activations.cpp" +#include "test/test_conv1d.cpp" +#include "test/test_convnet.cpp" #include "test/test_dsp.cpp" -#include "test/test_get_dsp.cpp" -#include "test/test_wavenet.cpp" #include "test/test_fast_lut.cpp" +#include "test/test_get_dsp.cpp" +#include "test/test_ring_buffer.cpp" +#include "test/test_wavenet/test_layer.cpp" +#include "test/test_wavenet/test_layer_array.cpp" +#include "test/test_wavenet/test_full.cpp" +#include "test/test_wavenet/test_real_time_safe.cpp" #include "test/test_gating_activations.cpp" #include "test/test_wavenet_gating_compatibility.cpp" #include "test/test_blending_detailed.cpp" @@ -25,6 +31,9 @@ int main() test_activations::TestLeakyReLU::test_get_by_init(); test_activations::TestLeakyReLU::test_get_by_str(); + test_lut::TestFastLUT::test_sigmoid(); + test_lut::TestFastLUT::test_tanh(); + test_activations::TestPReLU::test_core_function(); test_activations::TestPReLU::test_per_channel_behavior(); // This is enforced by an assert so it doesn't need to be tested @@ -43,10 +52,57 @@ int main() test_get_dsp::test_null_input_level(); test_get_dsp::test_null_output_level(); - test_wavenet::test_gated(); + test_ring_buffer::test_construct(); + test_ring_buffer::test_reset(); + test_ring_buffer::test_reset_with_receptive_field(); + test_ring_buffer::test_write(); + test_ring_buffer::test_read_with_lookback(); + test_ring_buffer::test_advance(); + test_ring_buffer::test_rewind(); + test_ring_buffer::test_multiple_writes_reads(); + test_ring_buffer::test_reset_zeros_history_area(); + test_ring_buffer::test_rewind_preserves_history(); - test_lut::TestFastLUT::test_sigmoid(); - test_lut::TestFastLUT::test_tanh(); + test_conv1d::test_construct(); + test_conv1d::test_set_size(); + test_conv1d::test_reset(); + test_conv1d::test_process_basic(); + test_conv1d::test_process_with_bias(); + test_conv1d::test_process_multichannel(); + test_conv1d::test_process_dilation(); + 
test_conv1d::test_process_multiple_calls(); + test_conv1d::test_get_output_different_sizes(); + test_conv1d::test_set_size_and_weights(); + test_conv1d::test_get_num_weights(); + test_conv1d::test_reset_multiple(); + + test_wavenet::test_layer::test_gated(); + test_wavenet::test_layer::test_layer_getters(); + test_wavenet::test_layer::test_non_gated_layer(); + test_wavenet::test_layer::test_layer_activations(); + test_wavenet::test_layer::test_layer_multichannel(); + test_wavenet::test_layer_array::test_layer_array_basic(); + test_wavenet::test_layer_array::test_layer_array_receptive_field(); + test_wavenet::test_layer_array::test_layer_array_with_head_input(); + test_wavenet::test_full::test_wavenet_model(); + test_wavenet::test_full::test_wavenet_multiple_arrays(); + test_wavenet::test_full::test_wavenet_zero_input(); + test_wavenet::test_full::test_wavenet_different_buffer_sizes(); + test_wavenet::test_full::test_wavenet_prewarm(); + test_wavenet::test_allocation_tracking_pass(); + test_wavenet::test_allocation_tracking_fail(); + test_wavenet::test_conv1d_process_realtime_safe(); + test_wavenet::test_layer_process_realtime_safe(); + test_wavenet::test_layer_array_process_realtime_safe(); + test_wavenet::test_process_realtime_safe(); + + test_convnet::test_convnet_basic(); + test_convnet::test_convnet_batchnorm(); + test_convnet::test_convnet_multiple_blocks(); + test_convnet::test_convnet_zero_input(); + test_convnet::test_convnet_different_buffer_sizes(); + test_convnet::test_convnet_prewarm(); + test_convnet::test_convnet_multiple_calls(); // Gating activations tests test_gating_activations::TestGatingActivation::test_basic_functionality(); diff --git a/tools/test/test_conv1d.cpp b/tools/test/test_conv1d.cpp new file mode 100644 index 0000000..3d94e27 --- /dev/null +++ b/tools/test/test_conv1d.cpp @@ -0,0 +1,409 @@ +// Tests for Conv1D + +#include +#include +#include +#include +#include + +#include "NAM/conv1d.h" + +namespace test_conv1d +{ +// Test basic construction +void test_construct() +{ + nam::Conv1D conv; + assert(conv.get_dilation() == 1); + assert(conv.get_in_channels() == 0); + assert(conv.get_out_channels() == 0); + assert(conv.get_kernel_size() == 0); +} + +// Test construction with provided shape +void test_construct_with_shape() +{ + nam::Conv1D conv(2, 3, 5, true, 7); + assert(conv.get_dilation() == 7); + assert(conv.get_in_channels() == 2); + assert(conv.get_out_channels() == 3); + assert(conv.get_kernel_size() == 5); + assert(conv.has_bias()); +} + +// Test set_size_ and getters +void test_set_size() +{ + nam::Conv1D conv; + const int in_channels = 2; + const int out_channels = 4; + const int kernel_size = 3; + const bool do_bias = true; + const int dilation = 2; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + assert(conv.get_in_channels() == in_channels); + assert(conv.get_out_channels() == out_channels); + assert(conv.get_kernel_size() == kernel_size); + assert(conv.get_dilation() == dilation); + assert(conv.has_bias() == do_bias); +} + +// Test Reset() initializes buffers +void test_reset() +{ + nam::Conv1D conv; + const int in_channels = 2; + const int out_channels = 4; + const int kernel_size = 3; + const int maxBufferSize = 64; + const double sampleRate = 48000.0; + + conv.set_size_(in_channels, out_channels, kernel_size, false, 1); + conv.SetMaxBufferSize(maxBufferSize); + + // After Reset, GetOutput should work + // (Even thoguh GetOutput() doesn't make sense to call before Process()) + auto output = 
conv.GetOutput().leftCols(maxBufferSize); + assert(output.rows() == out_channels); + assert(output.cols() == maxBufferSize); +} + +// Test basic Process() with simple convolution +void test_process_basic() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 2; + const bool do_bias = false; + const int dilation = 1; + const int num_frames = 4; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Set weights: kernel[0] = [[1.0]], kernel[1] = [[2.0]] + // With offset calculation: k=0 has offset=-1 (looks at t-1), k=1 has offset=0 (looks at t) + // So: output = weight[0] * input[t-1] + weight[1] * input[t] = 1.0 * input[t-1] + 2.0 * input[t] + std::vector weights{1.0f, 2.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + conv.SetMaxBufferSize(64); + + // Create input: [1.0, 2.0, 3.0, 4.0] + Eigen::MatrixXf input(in_channels, num_frames); + input(0, 0) = 1.0f; + input(0, 1) = 2.0f; + input(0, 2) = 3.0f; + input(0, 3) = 4.0f; + + // Process + conv.Process(input, num_frames); + + // Get output + auto output = conv.GetOutput().leftCols(num_frames); + + // Expected outputs (with zero padding for first frame): + // output[0] = 1.0 * 0.0 (zero-padding) + 2.0 * 1.0 = 2.0 + // output[1] = 1.0 * 1.0 + 2.0 * 2.0 = 5.0 + // output[2] = 1.0 * 2.0 + 2.0 * 3.0 = 8.0 + // output[3] = 1.0 * 3.0 + 2.0 * 4.0 = 11.0 + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + assert(abs(output(0, 0) - 2.0f) < 0.01f); + assert(abs(output(0, 1) - 5.0f) < 0.01f); + assert(abs(output(0, 2) - 8.0f) < 0.01f); + assert(abs(output(0, 3) - 11.0f) < 0.01f); +} + +// Test Process() with bias +void test_process_with_bias() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 2; + const bool do_bias = true; + const int dilation = 1; + const int num_frames = 2; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Set weights: kernel[0] = [[1.0]], kernel[1] = [[0.0]], bias = [5.0] + // With offset: k=0 has offset=-1 (looks at t-1), k=1 has offset=0 (looks at t) + // So: output = weight[0] * input[t-1] + weight[1] * input[t] + bias = 1.0 * input[t-1] + 0.0 * input[t] + 5.0 + std::vector weights{1.0f, 0.0f, 5.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf input(in_channels, num_frames); + input(0, 0) = 2.0f; + input(0, 1) = 3.0f; + + conv.Process(input, num_frames); + auto output = conv.GetOutput().leftCols(num_frames); + + // Should have bias added + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + // With zero-padding for first frame: + // First frame: weight[0]*zero + weight[1]*input[0] + bias = 1.0*0.0 + 0.0*2.0 + 5.0 = 5.0 + // Second frame: weight[0]*input[0] + weight[1]*input[1] + bias = 1.0*2.0 + 0.0*3.0 + 5.0 = 7.0 + assert(std::abs(output(0, 0) - 5.0f) < 0.01f); // First frame: zero-padding + bias + assert(std::abs(output(0, 1) - 7.0f) < 0.01f); // Second frame: input[0] + bias +} + +// Test Process() with multiple channels +void test_process_multichannel() +{ + nam::Conv1D conv; + const int in_channels = 2; + const int out_channels = 3; + const int kernel_size = 1; + const bool do_bias = false; + const int dilation = 1; + const int num_frames = 2; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Set simple identity-like weights for kernel[0] + // weight[0] should be (3, 2) matrix + // 
Let's use: [[1, 0], [0, 1], [1, 1]] which means: + // out[0] = in[0] + // out[1] = in[1] + // out[2] = in[0] + in[1] + std::vector weights; + // kernel[0] weights (3x2 matrix, row-major flattened) + weights.push_back(1.0f); // out[0], in[0] + weights.push_back(0.0f); // out[0], in[1] + weights.push_back(0.0f); // out[1], in[0] + weights.push_back(1.0f); // out[1], in[1] + weights.push_back(1.0f); // out[2], in[0] + weights.push_back(1.0f); // out[2], in[1] + + auto it = weights.begin(); + conv.set_weights_(it); + + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf input(in_channels, num_frames); + input(0, 0) = 1.0f; + input(1, 0) = 2.0f; + input(0, 1) = 3.0f; + input(1, 1) = 4.0f; + + conv.Process(input, num_frames); + auto output = conv.GetOutput().leftCols(num_frames); + + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + // out[0] = in[0] = 1.0 + // out[1] = in[1] = 2.0 + // out[2] = in[0] + in[1] = 3.0 + assert(std::abs(output(0, 0) - 1.0f) < 0.01f); + assert(std::abs(output(1, 0) - 2.0f) < 0.01f); + assert(std::abs(output(2, 0) - 3.0f) < 0.01f); +} + +// Test Process() with dilation +void test_process_dilation() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 2; + const bool do_bias = false; + const int dilation = 2; + const int num_frames = 4; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Set weights: kernel[0] = [[1.0]], kernel[1] = [[2.0]] + // With dilation=2: k=0 has offset=-2 (looks at t-2), k=1 has offset=0 (looks at t) + // So: output = weight[0] * input[t-2] + weight[1] * input[t] = 1.0 * input[t-2] + 2.0 * input[t] + std::vector weights{1.0f, 2.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf input(in_channels, num_frames); + input(0, 0) = 1.0f; + input(0, 1) = 2.0f; + input(0, 2) = 3.0f; + input(0, 3) = 4.0f; + + conv.Process(input, num_frames); + auto output = conv.GetOutput().leftCols(num_frames); + + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + // Output should be computed correctly with dilation (with zero-padding) + // out[0] = 1.0 * 0.0 (zero-padding) + 2.0 * 1.0 = 2.0 + // out[1] = 1.0 * 0.0 (zero-padding) + 2.0 * 2.0 = 4.0 + // out[2] = 1.0 * 1.0 + 2.0 * 3.0 = 1.0 + 6.0 = 7.0 + // out[3] = 1.0 * 2.0 + 2.0 * 4.0 = 2.0 + 8.0 = 10.0 + assert(abs(output(0, 0) - 2.0f) < 0.01f); + assert(abs(output(0, 1) - 4.0f) < 0.01f); + assert(abs(output(0, 2) - 7.0f) < 0.01f); + assert(abs(output(0, 3) - 10.0f) < 0.01f); +} + +// Test multiple Process() calls (ring buffer functionality) +void test_process_multiple_calls() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 2; + const bool do_bias = false; + const int dilation = 1; + const int num_frames = 2; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Set weights: kernel[0] = [[1.0]], kernel[1] = [[1.0]] + // With offset: k=0 has offset=-1 (looks at t-1), k=1 has offset=0 (looks at t) + // So: output = weight[0] * input[t-1] + weight[1] * input[t] = input[t-1] + input[t] + std::vector weights{1.0f, 1.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + conv.SetMaxBufferSize(num_frames); + + // 3 calls should trigger rewind. 
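+  // Why three: max lookback is (kernel_size - 1) * dilation = 1 and the buffer size is 2,
+  // so (assuming the 2 * max_lookback + max_buffer_size sizing used by Reset()) the ring
+  // storage is only 4 columns wide. Three 2-frame writes are 6 frames, which cannot fit
+  // without wrapping, so at least one rewind must occur; exactly which call triggers the
+  // rewind is an implementation detail of Write().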
+ Eigen::MatrixXf input(in_channels, num_frames); + input(0, 0) = 1.0f; + input(0, 1) = 2.0f; + for (int i = 0; i < 3; i++) + { + conv.Process(input, num_frames); + } + auto output = conv.GetOutput().leftCols(num_frames); + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + // After 3 calls, the last call processes input [1, 2] + // It should use history from the previous call (which also had [1, 2]) + // output[0] = weight[0] * (history from previous call's last frame) + weight[1] * input[0] + // = 1.0 * 2.0 + 1.0 * 1.0 = 3.0 + // output[1] = weight[0] * input[0] + weight[1] * input[1] + // = 1.0 * 1.0 + 1.0 * 2.0 = 3.0 + // This tests that ring buffer maintains history across multiple calls + assert(abs(output(0, 0) - 3.0f) < 0.01f); + assert(abs(output(0, 1) - 3.0f) < 0.01f); +} + +// Test GetOutput() with different num_frames +void test_get_output_different_sizes() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 1; + const bool do_bias = false; + const int maxBufferSize = 64; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, 1); + + // Identity weight + std::vector weights{1.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + conv.SetMaxBufferSize(maxBufferSize); + + Eigen::MatrixXf input(in_channels, 4); + input(0, 0) = 1.0f; + input(0, 1) = 2.0f; + input(0, 2) = 3.0f; + input(0, 3) = 4.0f; + + conv.Process(input, 4); + + // Get different sized outputs + auto output_all = conv.GetOutput().leftCols(4); + assert(output_all.cols() == 4); + + auto output_partial = conv.GetOutput().leftCols(2); + assert(output_partial.cols() == 2); + assert(output_partial.rows() == out_channels); +} + +// Test set_size_and_weights_ +void test_set_size_and_weights() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 2; + const bool do_bias = false; + const int dilation = 1; + + std::vector weights{1.0f, 2.0f}; + auto it = weights.begin(); + conv.set_size_and_weights_(in_channels, out_channels, kernel_size, dilation, do_bias, it); + + assert(conv.get_in_channels() == in_channels); + assert(conv.get_out_channels() == out_channels); + assert(conv.get_kernel_size() == kernel_size); + assert(conv.get_dilation() == dilation); + assert(it == weights.end()); // All weights should be consumed +} + +// Test get_num_weights() +void test_get_num_weights() +{ + nam::Conv1D conv; + const int in_channels = 2; + const int out_channels = 3; + const int kernel_size = 2; + const bool do_bias = true; + const int dilation = 1; + + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Expected: kernel_size * (out_channels * in_channels) + (bias ? 
out_channels : 0) + // = 2 * (3 * 2) + 3 = 2 * 6 + 3 = 15 + long expected = kernel_size * (out_channels * in_channels) + out_channels; + long actual = conv.get_num_weights(); + + assert(actual == expected); + + // Test without bias + nam::Conv1D conv_no_bias; + conv_no_bias.set_size_(in_channels, out_channels, kernel_size, false, dilation); + expected = kernel_size * (out_channels * in_channels); + actual = conv_no_bias.get_num_weights(); + assert(actual == expected); +} + +// Test that Reset() can be called multiple times +void test_reset_multiple() +{ + nam::Conv1D conv; + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 1; + + conv.set_size_(in_channels, out_channels, kernel_size, false, 1); + + std::vector weights{1.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + // Reset with different buffer sizes + conv.SetMaxBufferSize(64); + { + auto output1 = conv.GetOutput().leftCols(64); + assert(output1.cols() == 64); + } // output1 goes out of scope here, releasing the block reference + + conv.SetMaxBufferSize(128); + auto output2 = conv.GetOutput().leftCols(128); + assert(output2.cols() == 128); +} +}; // namespace test_conv1d diff --git a/tools/test/test_convnet.cpp b/tools/test/test_convnet.cpp new file mode 100644 index 0000000..ff11074 --- /dev/null +++ b/tools/test/test_convnet.cpp @@ -0,0 +1,275 @@ +// Tests for ConvNet + +#include +#include +#include +#include +#include + +#include "NAM/convnet.h" + +namespace test_convnet +{ +// Test basic ConvNet construction and processing +void test_convnet_basic() +{ + const int channels = 2; + const std::vector dilations{1, 2}; + const bool batchnorm = false; + const std::string activation = "ReLU"; + const double expected_sample_rate = 48000.0; + + // Calculate weights needed: + // Block 0: Conv1D (1, 2, 2, !batchnorm=true, 1) -> 2*1*2 = 4 weights + 2 bias = 6 total + // Block 1: Conv1D (2, 2, 2, !batchnorm=true, 2) -> 2*2*2 = 8 weights + 2 bias = 10 total + // Head: (2, 1) weight + 1 bias = 3 weights + // Total: 6 + 10 + 3 = 19 weights + std::vector weights; + // Block 0 weights (4 weights: kernel[0] and kernel[1], each 2x1) + 2 bias + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Block 1 weights (8 weights: kernel[0] and kernel[1], each 2x2) + 2 bias + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Head weights (2 weights + 1 bias) + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + const int numFrames = 4; + const int maxBufferSize = 64; + convnet.Reset(expected_sample_rate, maxBufferSize); + + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + + convnet.process(input.data(), output.data(), numFrames); + + // Verify output dimensions + assert(output.size() == numFrames); + // Output should be non-zero and finite + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test ConvNet with batchnorm +void test_convnet_batchnorm() +{ + const int channels = 1; + const std::vector dilations{1}; + const bool batchnorm = true; + const std::string activation = "ReLU"; + const double expected_sample_rate = 48000.0; + + // Calculate weights needed: + // Block 0: Conv1D (1, 1, 2, !batchnorm=false, 1) -> 2*1*1 = 2 weights (no bias when batchnorm=true) + // BatchNorm: running_mean(1) + running_var(1) + weight(1) + bias(1) + eps(1) = 5 weights + // Head: (1, 
1) weight + 1 bias = 2 weights + // Total: 2 + 5 + 2 = 9 weights + std::vector weights; + // Block 0 weights (2 weights: kernel[0], kernel[1], no bias) + weights.insert(weights.end(), {1.0f, 1.0f}); + // BatchNorm weights (5: mean, var, weight, bias, eps) + weights.insert(weights.end(), {0.0f, 1.0f, 1.0f, 0.0f, 1e-5f}); + // Head weights (1 weight + 1 bias) + weights.insert(weights.end(), {1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + const int numFrames = 4; + const int maxBufferSize = 64; + convnet.Reset(expected_sample_rate, maxBufferSize); + + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + + convnet.process(input.data(), output.data(), numFrames); + + assert(output.size() == numFrames); + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test ConvNet with multiple blocks +void test_convnet_multiple_blocks() +{ + const int channels = 2; + const std::vector dilations{1, 2, 4}; + const bool batchnorm = false; + const std::string activation = "Tanh"; + const double expected_sample_rate = 48000.0; + + // Calculate weights needed: + // Block 0: Conv1D (1, 2, 2, !batchnorm=true, 1) -> 2*1*2 = 4 weights + 2 bias = 6 total + // Block 1: Conv1D (2, 2, 2, !batchnorm=true, 2) -> 2*2*2 = 8 weights + 2 bias = 10 total + // Block 2: Conv1D (2, 2, 2, !batchnorm=true, 4) -> 2*2*2 = 8 weights + 2 bias = 10 total + // Head: (2, 1) weight + 1 bias = 3 weights + // Total: 6 + 10 + 10 + 3 = 29 weights + std::vector weights; + // Block 0 weights (4 weights + 2 bias) + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Block 1 weights (8 weights + 2 bias) + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Block 2 weights (8 weights + 2 bias) + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Head weights + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + const int numFrames = 8; + const int maxBufferSize = 64; + convnet.Reset(expected_sample_rate, maxBufferSize); + + std::vector input(numFrames, 0.5f); + std::vector output(numFrames, 0.0f); + + convnet.process(input.data(), output.data(), numFrames); + + assert(output.size() == numFrames); + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test ConvNet with zero input +void test_convnet_zero_input() +{ + const int channels = 1; + const std::vector dilations{1}; + const bool batchnorm = false; + const std::string activation = "ReLU"; + const double expected_sample_rate = 48000.0; + + std::vector weights; + // Block 0 weights (2 weights: kernel[0], kernel[1] + 1 bias, since batchnorm=false) + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f}); + // Head weights (1 weight + 1 bias) + weights.insert(weights.end(), {1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + const int numFrames = 4; + convnet.Reset(expected_sample_rate, numFrames); + + std::vector input(numFrames, 0.0f); + std::vector output(numFrames, 0.0f); + + convnet.process(input.data(), output.data(), numFrames); + + // With zero input, output should be finite (may be zero or non-zero depending on bias) + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test ConvNet 
with different buffer sizes +void test_convnet_different_buffer_sizes() +{ + const int channels = 1; + const std::vector dilations{1}; + const bool batchnorm = false; + const std::string activation = "ReLU"; + const double expected_sample_rate = 48000.0; + + std::vector weights; + // Block 0 weights (2 weights: kernel[0], kernel[1] + 1 bias, since batchnorm=false) + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f}); + // Head weights (1 weight + 1 bias) + weights.insert(weights.end(), {1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + // Test with different buffer sizes + convnet.Reset(expected_sample_rate, 64); + std::vector input1(32, 1.0f); + std::vector output1(32, 0.0f); + convnet.process(input1.data(), output1.data(), 32); + + convnet.Reset(expected_sample_rate, 128); + std::vector input2(64, 1.0f); + std::vector output2(64, 0.0f); + convnet.process(input2.data(), output2.data(), 64); + + // Both should work without errors + assert(output1.size() == 32); + assert(output2.size() == 64); +} + +// Test ConvNet prewarm functionality +void test_convnet_prewarm() +{ + const int channels = 2; + const std::vector dilations{1, 2, 4}; + const bool batchnorm = false; + const std::string activation = "ReLU"; + const double expected_sample_rate = 48000.0; + + std::vector weights; + // Block 0 weights (4 weights + 2 bias, since batchnorm=false) + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Block 1 weights (8 weights + 2 bias, since batchnorm=false) + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Block 2 weights (8 weights + 2 bias, since batchnorm=false) + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f}); + // Head weights (2 weights + 1 bias) + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + // Test that prewarm can be called without errors + convnet.Reset(expected_sample_rate, 64); + convnet.prewarm(); + + // After prewarm, processing should work + const int numFrames = 4; + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + convnet.process(input.data(), output.data(), numFrames); + + // Output should be finite + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test multiple process() calls (ring buffer functionality) +void test_convnet_multiple_calls() +{ + const int channels = 1; + const std::vector dilations{1}; + const bool batchnorm = false; + const std::string activation = "ReLU"; + const double expected_sample_rate = 48000.0; + + std::vector weights; + // Block 0 weights (2 weights: kernel[0], kernel[1] + 1 bias, since batchnorm=false) + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f}); + // Head weights (1 weight + 1 bias) + weights.insert(weights.end(), {1.0f, 0.0f}); + + nam::convnet::ConvNet convnet(channels, dilations, batchnorm, activation, weights, expected_sample_rate); + + const int numFrames = 2; + convnet.Reset(expected_sample_rate, numFrames); + + // Multiple calls should work correctly with ring buffer + for (int i = 0; i < 5; i++) + { + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + convnet.process(input.data(), output.data(), numFrames); + + // Output should be finite + for (int j = 0; j < numFrames; j++) + { + assert(std::isfinite(output[j])); + } + } +} +}; // 
namespace test_convnet diff --git a/tools/test/test_ring_buffer.cpp b/tools/test/test_ring_buffer.cpp new file mode 100644 index 0000000..3166928 --- /dev/null +++ b/tools/test/test_ring_buffer.cpp @@ -0,0 +1,323 @@ +// Tests for RingBuffer + +#include +#include +#include +#include +#include + +#include "NAM/ring_buffer.h" + +namespace test_ring_buffer +{ +// Test basic construction +void test_construct() +{ + nam::RingBuffer rb; + assert(rb.GetMaxBufferSize() == 0); + assert(rb.GetChannels() == 0); +} + +// Test Reset() initializes storage correctly +void test_reset() +{ + nam::RingBuffer rb; + const int channels = 2; + const int max_buffer_size = 64; + + rb.Reset(channels, max_buffer_size); + + assert(rb.GetChannels() == channels); + assert(rb.GetMaxBufferSize() == max_buffer_size); +} + +// Test Reset() with max_lookback zeros the storage behind starting position +void test_reset_with_receptive_field() +{ + nam::RingBuffer rb; + const int channels = 2; + const int max_buffer_size = 64; + const long max_lookback = 10; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + assert(rb.GetChannels() == channels); + assert(rb.GetMaxBufferSize() == max_buffer_size); + + // The storage behind the starting position should be zero + // Read from position 0 by using lookback = max_lookback (read_pos = write_pos - max_lookback = 0) + auto buffer_block = rb.Read(max_lookback, max_lookback); // Read from position 0 + assert(buffer_block.isZero()); +} + +// Test Write() writes data at write position +void test_write() +{ + nam::RingBuffer rb; + const int channels = 2; + const int max_buffer_size = 64; + const int num_frames = 4; + const long max_lookback = 4; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + Eigen::MatrixXf input(channels, num_frames); + input(0, 0) = 1.0f; + input(1, 0) = 2.0f; + input(0, 1) = 3.0f; + input(1, 1) = 4.0f; + input(0, 2) = 5.0f; + input(1, 2) = 6.0f; + input(0, 3) = 7.0f; + input(1, 3) = 8.0f; + + rb.Write(input, num_frames); + + // Read back what we just wrote (with lookback=0, since write_pos hasn't advanced) + auto output = rb.Read(num_frames, 0); + assert(output.rows() == channels); + assert(output.cols() == num_frames); + assert(std::abs(output(0, 0) - 1.0f) < 0.01f); + assert(std::abs(output(1, 0) - 2.0f) < 0.01f); + assert(std::abs(output(0, 1) - 3.0f) < 0.01f); + assert(std::abs(output(1, 1) - 4.0f) < 0.01f); + + // After Advance, we need lookback to read what we wrote + rb.Advance(num_frames); + auto output_after_advance = rb.Read(num_frames, num_frames); + assert(std::abs(output_after_advance(0, 0) - 1.0f) < 0.01f); + assert(std::abs(output_after_advance(1, 1) - 4.0f) < 0.01f); +} + +// Test Read() with lookback +void test_read_with_lookback() +{ + nam::RingBuffer rb; + const int channels = 1; + const int max_buffer_size = 64; + const long max_lookback = 5; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + // Write some data + Eigen::MatrixXf input1(channels, 3); + input1(0, 0) = 1.0f; + input1(0, 1) = 2.0f; + input1(0, 2) = 3.0f; + + rb.Write(input1, 3); + rb.Advance(3); + + // Write more data + Eigen::MatrixXf input2(channels, 2); + input2(0, 0) = 4.0f; + input2(0, 1) = 5.0f; + + rb.Write(input2, 2); + + // After Write(), data is at write_pos but write_pos hasn't advanced yet + // So lookback=0 reads from write_pos, which has the data we just wrote + auto current = rb.Read(2, 0); + assert(std::abs(current(0, 0) - 4.0f) < 0.01f); + assert(std::abs(current(0, 1) - 5.0f) < 
0.01f); + + // After Advance(3), write_pos = receptive_field + 3 = 8 + // Read with lookback=2 should get the last 2 frames from input1 + auto recent = rb.Read(2, 2); + // Position 8-2=6 has input1[1]=2.0, position 7 has input1[2]=3.0 + assert(std::abs(recent(0, 0) - 2.0f) < 0.01f); // input1[1] at position 6 + assert(std::abs(recent(0, 1) - 3.0f) < 0.01f); // input1[2] at position 7 + + rb.Advance(2); // Now write_pos = 10 + + // Read with lookback=2 to get input2 we just wrote + auto input2_read = rb.Read(2, 2); + assert(std::abs(input2_read(0, 0) - 4.0f) < 0.01f); + assert(std::abs(input2_read(0, 1) - 5.0f) < 0.01f); + + // Read with lookback=5 (should get frames from first write) + auto history = rb.Read(2, 5); + // Position 10-5=5 has input1[0]=1.0, position 6 has input1[1]=2.0 + assert(std::abs(history(0, 0) - 1.0f) < 0.01f); // input1[0] + assert(std::abs(history(0, 1) - 2.0f) < 0.01f); // input1[1] +} + +// Test Advance() moves write pointer +void test_advance() +{ + nam::RingBuffer rb; + const int channels = 1; + const int max_buffer_size = 64; + const long max_lookback = 15; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + // Test that Advance() works by writing, advancing, and reading back + Eigen::MatrixXf input(channels, 10); + input.setZero(); + input(0, 0) = 1.0f; + rb.Write(input, 10); + rb.Advance(10); + + // Read back with lookback to verify advance worked + auto output = rb.Read(10, 10); + assert(std::abs(output(0, 0) - 1.0f) < 0.01f); + + rb.Advance(5); + // Read back with larger lookback to verify further advance + auto output2 = rb.Read(10, 15); + assert(std::abs(output2(0, 0) - 1.0f) < 0.01f); +} + +// Test Rewind() copies history and resets write position +void test_rewind() +{ + nam::RingBuffer rb; + const int channels = 1; + const int max_buffer_size = 32; + const long max_lookback = 5; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + // Storage size = 2 * max_lookback + max_buffer_size = 2 * 5 + 32 = 42 + const long storage_size = 2 * max_lookback + max_buffer_size; + + // Write enough data to trigger rewind + // We need to write more than storage_size to trigger rewind + const int num_writes = 25; // 25 * 2 = 50 > 42 + long writeSize = 2; + assert(writeSize * num_writes > storage_size); + for (int i = 0; i < num_writes; i++) + { + Eigen::MatrixXf input(channels, writeSize); + input(0, 0) = (float)(i * 2); + input(0, 1) = (float)(i * 2 + 1); + + rb.Write(input, writeSize); + rb.Advance(writeSize); + + // Continue writing until we've written enough to potentially trigger rewind + // The rewind will happen automatically in Write() if needed + } + + // [SDA] this next part is an AI test and I'm not sure I like it. 
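+  // A more specific check would mirror the read-after-Advance pattern from test_write():
+  // the most recent write should still be readable with a lookback equal to the write
+  // size. Hypothetical asserts (not part of this test as written):
+  //   auto recent = rb.Read(writeSize, writeSize);
+  //   assert(std::abs(recent(0, 0) - (float)((num_writes - 1) * 2)) < 0.01f);
+  //   assert(std::abs(recent(0, 1) - (float)((num_writes - 1) * 2 + 1)) < 0.01f);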
+ + // After writing enough data, we should be able to read from history + // Read with lookback = max_lookback to read from position 0 (history region) + auto history = rb.Read(2, max_lookback); + // History should be available + assert(history.cols() == 2); +} + + +// Test multiple writes and reads maintain history correctly +void test_multiple_writes_reads() +{ + nam::RingBuffer rb; + const int channels = 1; + const int max_buffer_size = 64; + const long max_lookback = 5; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + // Write first batch + Eigen::MatrixXf input1(channels, 3); + input1(0, 0) = 1.0f; + input1(0, 1) = 2.0f; + input1(0, 2) = 3.0f; + + rb.Write(input1, 3); + rb.Advance(3); + + // Write second batch + Eigen::MatrixXf input2(channels, 2); + input2(0, 0) = 4.0f; + input2(0, 1) = 5.0f; + + rb.Write(input2, 2); + rb.Advance(2); + + // After Write() and Advance(), write_pos points after the data we just wrote + // Read with lookback=2 to get the last 2 frames we wrote (input2) + auto current = rb.Read(2, 2); + assert(std::abs(current(0, 0) - 4.0f) < 0.01f); + assert(std::abs(current(0, 1) - 5.0f) < 0.01f); + + // Read with lookback=5 should get frames from first batch (input1[1] and input1[2]) + // After writes: input1 at positions [max_lookback, max_lookback+2] = [3, 4, 5] + // input2 at positions [max_lookback+3, max_lookback+4] = [6, 7] + // write_pos after both: max_lookback + 5 = 8 + // Read with lookback=5: read_pos = 8 - 5 = 3 + // This reads from position 3, which is input1[0] = 1.0 + auto history = rb.Read(2, 5); + // Position 3 = input1[0] = 1.0, position 4 = input1[1] = 2.0 + assert(std::abs(history(0, 0) - 1.0f) < 0.01f); + assert(std::abs(history(0, 1) - 2.0f) < 0.01f); +} + +// Test that Reset() zeros buffer behind starting position +void test_reset_zeros_history_area() +{ + nam::RingBuffer rb; + const int channels = 1; + const int max_buffer_size = 64; + const long max_lookback = 10; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + // Write some data and advance + Eigen::MatrixXf input(channels, max_buffer_size); + input.fill(42.0f); + for (int i = 0; i < 5; i++) // Should be enough to write those first positions. + { + rb.Write(input, max_buffer_size); + rb.Advance(max_buffer_size); + } + + // Reset should zero the storage behind the starting position + rb.Reset(channels, max_buffer_size); + + // Read from position 0 (behind starting write position) + // This should be zero + // After Reset with max_lookback, we can read from position 0 + auto read = rb.Read(max_lookback, max_lookback); + assert(read.isZero()); +} + +// Test Rewind() preserves history correctly +void test_rewind_preserves_history() +{ + nam::RingBuffer rb; + const int channels = 1; + const int max_buffer_size = 32; + const long max_lookback = 4; + + rb.SetMaxLookback(max_lookback); + rb.Reset(channels, max_buffer_size); + + // Storage size = 2 * max_lookback + max_buffer_size = 2 * 4 + 32 = 40 + const long storage_size = 2 * max_lookback + max_buffer_size; + + // Three writes of size max_buffer_size should trigger rewind. 
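+  // Arithmetic: with a 40-column storage and the write position starting at
+  // max_lookback (4), the second 32-frame write already runs past the end, so Write()
+  // must rewind at least once. Assuming the rewind copies the last max_lookback frames
+  // of history to the front before writing, the history read at the end of this test
+  // should equal the tail of the constant 42.0f input.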
+ Eigen::MatrixXf input(channels, max_buffer_size); + input.fill(42.0f); + for (int i = 0; i < 3; i++) + { + rb.Write(input, max_buffer_size); + rb.Advance(max_buffer_size); + } + + // Read from history region to verify rewind preserved history + auto history = rb.Read(max_lookback, max_lookback); + assert(history.cols() == max_lookback); + assert(history == input.rightCols(max_lookback)); +} + +}; // namespace test_ring_buffer diff --git a/tools/test/test_wavenet.cpp b/tools/test/test_wavenet.cpp deleted file mode 100644 index 87390cd..0000000 --- a/tools/test/test_wavenet.cpp +++ /dev/null @@ -1,76 +0,0 @@ -// Tests for the WaveNet - -#include -#include -#include - -#include "NAM/wavenet.h" - -namespace test_wavenet -{ -void test_gated() -{ - // Assert correct nuemrics of the gating activation. - // Issue 101 - const int conditionSize = 1; - const int channels = 1; - const int kernelSize = 1; - const int dilation = 1; - const std::string activation = "ReLU"; - const bool gated = true; - auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated); - - // Conv, input mixin, 1x1 - std::vector weights{ - // Conv (weight, bias) NOTE: 2 channels out bc gated, so shapes are (2,1,1), (2,) - 1.0f, 1.0f, 0.0f, 0.0f, - // Input mixin (weight only: (2,1,1)) - 1.0f, -1.0f, - // 1x1 (weight (1,1,1), bias (1,)) - // NOTE: Weights are (1,1) on conv, (1,-1), so the inputs sum on the upper channel and cancel on the lower. - // This should give us a nice zero if the input & condition are the same, so that'll sigmoid to 0.5 for the - // gate. - 1.0f, 0.0f}; - auto it = weights.begin(); - layer.set_weights_(it); - assert(it == weights.end()); - - const long numFrames = 4; - layer.SetMaxBufferSize(numFrames); - - Eigen::MatrixXf input, condition, headInput, output; - input.resize(channels, numFrames); - condition.resize(channels, numFrames); - headInput.resize(channels, numFrames); - output.resize(channels, numFrames); - - const float signalValue = 0.25f; - input.fill(signalValue); - condition.fill(signalValue); - // So input & condition will sum to 0.5 on the top channel (-> ReLU), cancel to 0 on bottom (-> sigmoid) - - headInput.setZero(); - output.setZero(); - - layer.process_(input, condition, headInput, output, 0, 0, (int)numFrames); - - // 0.25 + 0.25 -> 0.5 for conv & input mixin top channel - // (0 on bottom channel) - // Top ReLU -> preseves 0.5 - // Bottom sigmoid 0->0.5 - // Product is 0.25 - // 1x1 is unity - // Skip-connect -> 0.25 (input) + 0.25 (output) -> 0.5 output - // head output gets 0+0.25 = 0.25 - const float expectedOutput = 0.5; - const float expectedHeadInput = 0.25; - for (int i = 0; i < numFrames; i++) - { - const float actualOutput = output(0, i); - const float actualHeadInput = headInput(0, i); - // std::cout << actualOutput << std::endl; - assert(actualOutput == expectedOutput); - assert(actualHeadInput == expectedHeadInput); - } -} -}; // namespace test_wavenet \ No newline at end of file diff --git a/tools/test/test_wavenet/test_full.cpp b/tools/test/test_wavenet/test_full.cpp new file mode 100644 index 0000000..8a72ab3 --- /dev/null +++ b/tools/test/test_wavenet/test_full.cpp @@ -0,0 +1,255 @@ +// Tests for full WaveNet model + +#include +#include +#include +#include +#include + +#include "NAM/wavenet.h" + +namespace test_wavenet +{ +namespace test_full +{ +// Test full WaveNet model +void test_wavenet_model() +{ + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int channels = 1; + const int 
kernel_size = 1; + std::vector dilations{1}; + const std::string activation = "ReLU"; + const bool gated = false; + const bool head_bias = false; + const float head_scale = 1.0f; + const bool with_head = false; + + nam::wavenet::LayerArrayParams params( + input_size, condition_size, head_size, channels, kernel_size, std::move(dilations), activation, gated, head_bias); + std::vector layer_array_params; + layer_array_params.push_back(std::move(params)); + + // Calculate weights needed + // Layer array 0: + // Rechannel: (1,1) weight + // Layer 0: conv (1,1,1) + bias, input_mixin (1,1), 1x1 (1,1) + bias + // Head rechannel: (1,1) weight + // Head scale: 1 float + std::vector weights; + weights.push_back(1.0f); // Rechannel + weights.insert(weights.end(), {1.0f, 0.0f, 1.0f, 1.0f, 0.0f}); // Layer 0 + weights.push_back(1.0f); // Head rechannel + weights.push_back(head_scale); // Head scale + + auto wavenet = std::make_unique(layer_array_params, head_scale, with_head, weights, 48000.0); + + const int numFrames = 4; + const int maxBufferSize = 64; + wavenet->Reset(48000.0, maxBufferSize); + + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + + wavenet->process(input.data(), output.data(), numFrames); + + // Verify output dimensions + assert(output.size() == numFrames); + // Output should be non-zero + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test WaveNet with multiple layer arrays +void test_wavenet_multiple_arrays() +{ + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int channels = 1; + const int kernel_size = 1; + std::vector dilations{1}; + const std::string activation = "ReLU"; + const bool gated = false; + const bool head_bias = false; + const float head_scale = 0.5f; + const bool with_head = false; + + std::vector layer_array_params; + // First array + layer_array_params.emplace_back( + input_size, condition_size, head_size, channels, kernel_size, std::vector{1}, activation, gated, head_bias); + // Second array (head_size of first must match channels of second) + layer_array_params.emplace_back( + head_size, condition_size, head_size, channels, kernel_size, std::vector{1}, activation, gated, head_bias); + + std::vector weights; + // Array 0: rechannel, layer, head_rechannel + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f}); + // Array 1: rechannel, layer, head_rechannel + weights.insert(weights.end(), {1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f}); + weights.push_back(head_scale); + + auto wavenet = std::make_unique(layer_array_params, head_scale, with_head, weights, 48000.0); + + const int numFrames = 4; + const int maxBufferSize = 64; + wavenet->Reset(48000.0, maxBufferSize); + + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + + wavenet->process(input.data(), output.data(), numFrames); + + assert(output.size() == numFrames); + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test WaveNet with zero input +void test_wavenet_zero_input() +{ + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int channels = 1; + const int kernel_size = 1; + std::vector dilations{1}; + const std::string activation = "ReLU"; + const bool gated = false; + const bool head_bias = false; + const float head_scale = 1.0f; + const bool with_head = false; + + nam::wavenet::LayerArrayParams params( + input_size, condition_size, head_size, channels, kernel_size, 
std::move(dilations), activation, gated, head_bias); + std::vector layer_array_params; + layer_array_params.push_back(std::move(params)); + + std::vector weights{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, head_scale}; + + auto wavenet = std::make_unique(layer_array_params, head_scale, with_head, weights, 48000.0); + + const int numFrames = 4; + wavenet->Reset(48000.0, numFrames); + + std::vector input(numFrames, 0.0f); + std::vector output(numFrames, 0.0f); + + wavenet->process(input.data(), output.data(), numFrames); + + // With zero input, output should be finite (may be zero or non-zero depending on bias) + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} + +// Test WaveNet with different buffer sizes +void test_wavenet_different_buffer_sizes() +{ + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int channels = 1; + const int kernel_size = 1; + std::vector dilations{1}; + const std::string activation = "ReLU"; + const bool gated = false; + const bool head_bias = false; + const float head_scale = 1.0f; + const bool with_head = false; + + nam::wavenet::LayerArrayParams params( + input_size, condition_size, head_size, channels, kernel_size, std::move(dilations), activation, gated, head_bias); + std::vector layer_array_params; + layer_array_params.push_back(std::move(params)); + + std::vector weights{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, head_scale}; + + auto wavenet = std::make_unique(layer_array_params, head_scale, with_head, weights, 48000.0); + + // Test with different buffer sizes + wavenet->Reset(48000.0, 64); + std::vector input1(32, 1.0f); + std::vector output1(32, 0.0f); + wavenet->process(input1.data(), output1.data(), 32); + + wavenet->Reset(48000.0, 128); + std::vector input2(64, 1.0f); + std::vector output2(64, 0.0f); + wavenet->process(input2.data(), output2.data(), 64); + + // Both should work without errors + assert(output1.size() == 32); + assert(output2.size() == 64); +} + +// Test WaveNet prewarm functionality +void test_wavenet_prewarm() +{ + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int channels = 1; + const int kernel_size = 3; + std::vector dilations{1, 2, 4}; + const std::string activation = "ReLU"; + const bool gated = false; + const bool head_bias = false; + const float head_scale = 1.0f; + const bool with_head = false; + + nam::wavenet::LayerArrayParams params( + input_size, condition_size, head_size, channels, kernel_size, std::move(dilations), activation, gated, head_bias); + std::vector layer_array_params; + layer_array_params.push_back(std::move(params)); + + std::vector weights; + // Rechannel: (1,1) weight, no bias + weights.push_back(1.0f); + // 3 layers: each needs: + // Conv: kernel_size=3, in_channels=1, out_channels=1, bias=true -> 3*1*1 + 1 = 4 weights + // Input mixin: condition_size=1, out_channels=1, no bias -> 1 weight + // 1x1: in_channels=1, out_channels=1, bias=true -> 1*1 + 1 = 2 weights + // Total per layer: 7 weights + for (int i = 0; i < 3; i++) + { + // Conv weights: 3 weights (kernel_size * in_channels * out_channels) + 1 bias + weights.insert(weights.end(), {1.0f, 1.0f, 1.0f, 0.0f}); + // Input mixin: 1 weight + weights.push_back(1.0f); + // 1x1: 1 weight + 1 bias + weights.insert(weights.end(), {1.0f, 0.0f}); + } + // Head rechannel: (1,1) weight, no bias + weights.push_back(1.0f); + weights.push_back(head_scale); + + auto wavenet = std::make_unique(layer_array_params, head_scale, with_head, weights, 
48000.0); + + // Test that prewarm can be called without errors + wavenet->Reset(48000.0, 64); + wavenet->prewarm(); + + // After prewarm, processing should work + const int numFrames = 4; + std::vector input(numFrames, 1.0f); + std::vector output(numFrames, 0.0f); + wavenet->process(input.data(), output.data(), numFrames); + + // Output should be finite + for (int i = 0; i < numFrames; i++) + { + assert(std::isfinite(output[i])); + } +} +}; // namespace test_full + +} // namespace test_wavenet diff --git a/tools/test/test_wavenet/test_layer.cpp b/tools/test/test_wavenet/test_layer.cpp new file mode 100644 index 0000000..40f9439 --- /dev/null +++ b/tools/test/test_wavenet/test_layer.cpp @@ -0,0 +1,265 @@ +// Tests for WaveNet Layer + +#include +#include +#include +#include +#include + +#include "NAM/wavenet.h" + +namespace test_wavenet +{ +namespace test_layer +{ +void test_gated() +{ + // Assert correct nuemrics of the gating activation. + // Issue 101 + const int conditionSize = 1; + const int channels = 1; + const int kernelSize = 1; + const int dilation = 1; + const std::string activation = "ReLU"; + const bool gated = true; + auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated); + + // Conv, input mixin, 1x1 + std::vector weights{ + // Conv (weight, bias) NOTE: 2 channels out bc gated, so shapes are (2,1,1), (2,) + 1.0f, 1.0f, 0.0f, 0.0f, + // Input mixin (weight only: (2,1,1)) + 1.0f, -1.0f, + // 1x1 (weight (1,1,1), bias (1,)) + // NOTE: Weights are (1,1) on conv, (1,-1), so the inputs sum on the upper channel and cancel on the lower. + // This should give us a nice zero if the input & condition are the same, so that'll sigmoid to 0.5 for the + // gate. + 1.0f, 0.0f}; + auto it = weights.begin(); + layer.set_weights_(it); + assert(it == weights.end()); + + const long numFrames = 4; + layer.SetMaxBufferSize(numFrames); + + Eigen::MatrixXf input, condition, headInput, output; + input.resize(channels, numFrames); + condition.resize(conditionSize, numFrames); + headInput.resize(channels, numFrames); + output.resize(channels, numFrames); + + const float signalValue = 0.25f; + input.fill(signalValue); + condition.fill(signalValue); + // So input & condition will sum to 0.5 on the top channel (-> ReLU), cancel to 0 on bottom (-> sigmoid) + + headInput.setZero(); + output.setZero(); + + layer.Process(input, condition, (int)numFrames); + // Get outputs + auto layer_output = layer.GetOutputNextLayer().leftCols((int)numFrames); + auto head_output = layer.GetOutputHead().leftCols((int)numFrames); + // Copy to test buffers for verification + output.leftCols((int)numFrames) = layer_output; + headInput.leftCols((int)numFrames) = head_output; + + // 0.25 + 0.25 -> 0.5 for conv & input mixin top channel + // (0 on bottom channel) + // Top ReLU -> preseves 0.5 + // Bottom sigmoid 0->0.5 + // Product is 0.25 + // 1x1 is unity + // Skip-connect -> 0.25 (input) + 0.25 (output) -> 0.5 output + // head output gets 0+0.25 = 0.25 + const float expectedOutput = 0.5; + const float expectedHeadInput = 0.25; + for (int i = 0; i < numFrames; i++) + { + const float actualOutput = output(0, i); + const float actualHeadInput = headInput(0, i); + // std::cout << actualOutput << std::endl; + assert(actualOutput == expectedOutput); + assert(actualHeadInput == expectedHeadInput); + } +} + +// Test layer getters +void test_layer_getters() +{ + const int conditionSize = 2; + const int channels = 4; + const int kernelSize = 3; + const int dilation = 2; + const std::string 
activation = "Tanh"; + const bool gated = false; + + auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated); + + assert(layer.get_channels() == channels); + assert(layer.get_kernel_size() == kernelSize); + assert(layer.get_dilation() == dilation); +} + +// Test non-gated layer processing +void test_non_gated_layer() +{ + const int conditionSize = 1; + const int channels = 1; + const int kernelSize = 1; + const int dilation = 1; + const std::string activation = "ReLU"; + const bool gated = false; + + auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated); + + // For non-gated: conv outputs 1 channel, input_mixin outputs 1 channel, 1x1 outputs 1 channel + // Conv: (1,1,1) weight + (1,) bias + // Input mixin: (1,1) weight (no bias) + // 1x1: (1,1) weight + (1,) bias + std::vector weights{// Conv: weight=1.0, bias=0.0 + 1.0f, 0.0f, + // Input mixin: weight=1.0 + 1.0f, + // 1x1: weight=1.0, bias=0.0 + 1.0f, 0.0f}; + + auto it = weights.begin(); + layer.set_weights_(it); + assert(it == weights.end()); + + const int numFrames = 4; + layer.SetMaxBufferSize(numFrames); + + Eigen::MatrixXf input(channels, numFrames); + Eigen::MatrixXf condition(conditionSize, numFrames); + input.fill(1.0f); + condition.fill(1.0f); + + layer.Process(input, condition, numFrames); + + auto layer_output = layer.GetOutputNextLayer().leftCols(numFrames); + auto head_output = layer.GetOutputHead().leftCols(numFrames); + + assert(layer_output.rows() == channels); + assert(layer_output.cols() == numFrames); + assert(head_output.rows() == channels); + assert(head_output.cols() == numFrames); + + // With identity-like weights: input=1, condition=1 + // conv output = 1*1 + 0 = 1 + // input_mixin output = 1*1 = 1 + // z = 1 + 1 = 2 + // ReLU(2) = 2 + // 1x1 output = 1*2 + 0 = 2 + // layer_output = input + 1x1_output = 1 + 2 = 3 + // head_output = activated z = 2 + const float expectedLayerOutput = 3.0f; + const float expectedHeadOutput = 2.0f; + for (int i = 0; i < numFrames; i++) + { + assert(std::abs(layer_output(0, i) - expectedLayerOutput) < 0.01f); + assert(std::abs(head_output(0, i) - expectedHeadOutput) < 0.01f); + } +} + +// Test layer with different activations +void test_layer_activations() +{ + const int conditionSize = 1; + const int channels = 1; + const int kernelSize = 1; + const int dilation = 1; + const bool gated = false; + + // Test Tanh activation + { + auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, "Tanh", gated); + std::vector weights{1.0f, 0.0f, 1.0f, 1.0f, 0.0f}; + auto it = weights.begin(); + layer.set_weights_(it); + + const int numFrames = 2; + layer.SetMaxBufferSize(numFrames); + + Eigen::MatrixXf input(channels, numFrames); + Eigen::MatrixXf condition(conditionSize, numFrames); + input.fill(0.5f); + condition.fill(0.5f); + + layer.Process(input, condition, numFrames); + auto head_output = layer.GetOutputHead().leftCols(numFrames); + + // Should have applied Tanh activation, so output should be between -1 and 1. 
+    assert(head_output(0, 0) <= 1.0f); +    assert(head_output(0, 0) >= -1.0f); +  } +} + +// Test layer with multiple channels +void test_layer_multichannel() +{ +  const int conditionSize = 2; +  const int channels = 3; +  const int kernelSize = 1; +  const int dilation = 1; +  const std::string activation = "ReLU"; +  const bool gated = false; + +  auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated); + +  assert(layer.get_channels() == channels); + +  const int numFrames = 2; +  layer.SetMaxBufferSize(numFrames); + +  // Set identity-like weights (simplified) +  // Conv: (3,3,1) weights + (3,) bias +  // Input mixin: (3,2) weights +  // 1x1: (3,3) weights + (3,) bias +  std::vector<float> weights; +  // Conv weights: 3x3 identity matrix flattened +  for (int i = 0; i < 3; i++) +  { +    for (int j = 0; j < 3; j++) +    { +      weights.push_back((i == j) ? 1.0f : 0.0f); +    } +  } +  // Conv bias: zeros +  weights.insert(weights.end(), {0.0f, 0.0f, 0.0f}); +  // Input mixin: (3,2) zeros +  weights.insert(weights.end(), {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); +  // 1x1: (3,3) identity +  for (int i = 0; i < 3; i++) +  { +    for (int j = 0; j < 3; j++) +    { +      weights.push_back((i == j) ? 1.0f : 0.0f); +    } +  } +  // 1x1 bias: zeros +  weights.insert(weights.end(), {0.0f, 0.0f, 0.0f}); + +  auto it = weights.begin(); +  layer.set_weights_(it); +  assert(it == weights.end()); + +  Eigen::MatrixXf input(channels, numFrames); +  Eigen::MatrixXf condition(conditionSize, numFrames); +  input.fill(1.0f); +  condition.fill(1.0f); + +  layer.Process(input, condition, numFrames); + +  auto layer_output = layer.GetOutputNextLayer().leftCols(numFrames); +  auto head_output = layer.GetOutputHead().leftCols(numFrames); + +  assert(layer_output.rows() == channels); +  assert(layer_output.cols() == numFrames); +  assert(head_output.rows() == channels); +  assert(head_output.cols() == numFrames); +} +}; // namespace test_layer + +} // namespace test_wavenet \ No newline at end of file diff --git a/tools/test/test_wavenet/test_layer_array.cpp b/tools/test/test_wavenet/test_layer_array.cpp new file mode 100644 index 0000000..560bdb6 --- /dev/null +++ b/tools/test/test_wavenet/test_layer_array.cpp @@ -0,0 +1,133 @@ +// Tests for WaveNet LayerArray + +#include <cassert> +#include <cmath> +#include <iostream> +#include <string> +#include <vector> + +#include "NAM/wavenet.h" + +namespace test_wavenet +{ +namespace test_layer_array +{ +// Test layer array construction and basic processing +void test_layer_array_basic() +{ +  const int input_size = 1; +  const int condition_size = 1; +  const int head_size = 1; +  const int channels = 1; +  const int kernel_size = 1; +  std::vector<int> dilations{1, 2}; +  const std::string activation = "ReLU"; +  const bool gated = false; +  const bool head_bias = false; + +  auto layer_array = nam::wavenet::_LayerArray( +    input_size, condition_size, head_size, channels, kernel_size, dilations, activation, gated, head_bias); + +  const int numFrames = 4; +  layer_array.SetMaxBufferSize(numFrames); + +  // Calculate expected number of weights +  // Rechannel: (1,1) weight (no bias) +  // Layer 0: conv (1,1,1) + bias, input_mixin (1,1), 1x1 (1,1) + bias +  // Layer 1: conv (1,1,1) + bias, input_mixin (1,1), 1x1 (1,1) + bias +  // Head rechannel: (1,1) weight (no bias) +  std::vector<float> weights; +  // Rechannel +  weights.push_back(1.0f); +  // Layer 0: conv (weight=1, bias=0), input_mixin (weight=1), 1x1 (weight=1, bias=0) +  weights.insert(weights.end(), {1.0f, 0.0f, 1.0f, 1.0f, 0.0f}); +  // Layer 1: conv (weight=1, bias=0), input_mixin (weight=1), 1x1 (weight=1, bias=0) +
weights.insert(weights.end(), {1.0f, 0.0f, 1.0f, 1.0f, 0.0f}); +  // Head rechannel +  weights.push_back(1.0f); + +  auto it = weights.begin(); +  layer_array.set_weights_(it); +  assert(it == weights.end()); + +  Eigen::MatrixXf layer_inputs(input_size, numFrames); +  Eigen::MatrixXf condition(condition_size, numFrames); +  layer_inputs.fill(1.0f); +  condition.fill(1.0f); + +  layer_array.Process(layer_inputs, condition, numFrames); + +  auto layer_outputs = layer_array.GetLayerOutputs().leftCols(numFrames); +  auto head_outputs = layer_array.GetHeadOutputs().leftCols(numFrames); + +  assert(layer_outputs.rows() == channels); +  assert(layer_outputs.cols() == numFrames); +  assert(head_outputs.rows() == head_size); +  assert(head_outputs.cols() == numFrames); +} + +// Test layer array receptive field calculation +void test_layer_array_receptive_field() +{ +  const int input_size = 1; +  const int condition_size = 1; +  const int head_size = 1; +  const int channels = 1; +  const int kernel_size = 3; +  std::vector<int> dilations{1, 2, 4}; +  const std::string activation = "ReLU"; +  const bool gated = false; +  const bool head_bias = false; + +  auto layer_array = nam::wavenet::_LayerArray( +    input_size, condition_size, head_size, channels, kernel_size, dilations, activation, gated, head_bias); + +  long rf = layer_array.get_receptive_field(); +  // Expected: sum of dilation * (kernel_size - 1) for each layer +  // Layer 0: 1 * (3-1) = 2 +  // Layer 1: 2 * (3-1) = 4 +  // Layer 2: 4 * (3-1) = 8 +  // Total: 2 + 4 + 8 = 14 +  long expected_rf = 1 * (kernel_size - 1) + 2 * (kernel_size - 1) + 4 * (kernel_size - 1); +  assert(rf == expected_rf); +} + +// Test layer array with head input from previous array +void test_layer_array_with_head_input() +{ +  const int input_size = 1; +  const int condition_size = 1; +  const int head_size = 1; +  const int channels = 1; +  const int kernel_size = 1; +  std::vector<int> dilations{1}; +  const std::string activation = "ReLU"; +  const bool gated = false; +  const bool head_bias = false; + +  auto layer_array = nam::wavenet::_LayerArray( +    input_size, condition_size, head_size, channels, kernel_size, dilations, activation, gated, head_bias); + +  const int numFrames = 2; +  layer_array.SetMaxBufferSize(numFrames); + +  std::vector<float> weights{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f}; +  auto it = weights.begin(); +  layer_array.set_weights_(it); + +  Eigen::MatrixXf layer_inputs(input_size, numFrames); +  Eigen::MatrixXf condition(condition_size, numFrames); +  Eigen::MatrixXf head_inputs(head_size, numFrames); +  layer_inputs.fill(1.0f); +  condition.fill(1.0f); +  head_inputs.fill(0.5f); + +  layer_array.Process(layer_inputs, condition, head_inputs, numFrames); + +  auto head_outputs = layer_array.GetHeadOutputs().leftCols(numFrames); +  assert(head_outputs.rows() == head_size); +  assert(head_outputs.cols() == numFrames); +} +}; // namespace test_layer_array + +} // namespace test_wavenet diff --git a/tools/test/test_wavenet/test_real_time_safe.cpp b/tools/test/test_wavenet/test_real_time_safe.cpp new file mode 100644 index 0000000..21d36b7 --- /dev/null +++ b/tools/test/test_wavenet/test_real_time_safe.cpp @@ -0,0 +1,456 @@ +// Test to verify WaveNet::process is real-time safe (no allocations/frees) + +#include <cassert> +#include <cmath> +#include <cstdlib> +#include <dlfcn.h> +#include <iostream> +#include <memory> +#include <new> +#include <string> +#include <vector> + +#include "NAM/wavenet.h" +#include "NAM/conv1d.h" + +// Allocation tracking +namespace +{ +volatile int g_allocation_count = 0; +volatile int g_deallocation_count = 0; +volatile bool g_tracking_enabled = false; + +// Original
malloc/free functions +void* (*original_malloc)(size_t) = nullptr; +void (*original_free)(void*) = nullptr; +void* (*original_realloc)(void*, size_t) = nullptr; +} // namespace + +// Override malloc/free to track Eigen allocations (Eigen uses malloc directly) +extern "C" { +void* malloc(size_t size) +{ +  if (!original_malloc) +    original_malloc = reinterpret_cast<void* (*)(size_t)>(dlsym(RTLD_NEXT, "malloc")); +  void* ptr = original_malloc(size); +  if (g_tracking_enabled && ptr != nullptr) +    ++g_allocation_count; +  return ptr; +} + +void free(void* ptr) +{ +  if (!original_free) +    original_free = reinterpret_cast<void (*)(void*)>(dlsym(RTLD_NEXT, "free")); +  if (g_tracking_enabled && ptr != nullptr) +    ++g_deallocation_count; +  original_free(ptr); +} + +void* realloc(void* ptr, size_t size) +{ +  if (!original_realloc) +    original_realloc = reinterpret_cast<void* (*)(void*, size_t)>(dlsym(RTLD_NEXT, "realloc")); +  void* new_ptr = original_realloc(ptr, size); +  if (g_tracking_enabled) +  { +    if (ptr != nullptr && new_ptr != ptr) +      ++g_deallocation_count; // Old pointer was freed +    if (new_ptr != nullptr && new_ptr != ptr) +      ++g_allocation_count; // New allocation +  } +  return new_ptr; +} +} + +// Overload global new/delete operators to track allocations +void* operator new(std::size_t size) +{ +  void* ptr = std::malloc(size); +  if (!ptr) +    throw std::bad_alloc(); +  if (g_tracking_enabled) +    ++g_allocation_count; +  return ptr; +} + +void* operator new[](std::size_t size) +{ +  void* ptr = std::malloc(size); +  if (!ptr) +    throw std::bad_alloc(); +  if (g_tracking_enabled) +    ++g_allocation_count; +  return ptr; +} + +void operator delete(void* ptr) noexcept +{ +  if (g_tracking_enabled && ptr != nullptr) +    ++g_deallocation_count; +  std::free(ptr); +} + +void operator delete[](void* ptr) noexcept +{ +  if (g_tracking_enabled && ptr != nullptr) +    ++g_deallocation_count; +  std::free(ptr); +} + +namespace test_wavenet +{ +// Test that pre-allocated Eigen operations with noalias() don't allocate +void test_allocation_tracking_pass() +{ +  const int rows = 10; +  const int cols = 20; + +  // Pre-allocate matrices for matrix product: c = a * b +  // a is rows x cols, b is cols x rows, so c is rows x rows +  Eigen::MatrixXf a(rows, cols); +  Eigen::MatrixXf b(cols, rows); +  Eigen::MatrixXf c(rows, rows); + +  a.setConstant(1.0f); +  b.setConstant(2.0f); + +  // Reset allocation counters +  g_allocation_count = 0; +  g_deallocation_count = 0; +  g_tracking_enabled = true; + +  // Matrix product with noalias() - should not allocate (all matrices pre-allocated) +  // Using noalias() is important for matrix products to avoid unnecessary temporaries +  // Note: Even without noalias(), Eigen may avoid temporaries when matrices are distinct, +  // but noalias() is best practice for real-time safety +  c.noalias() = a * b; + +  // Disable tracking before any cleanup +  g_tracking_enabled = false; + +  // Assert no allocations or frees occurred +  assert(g_allocation_count == 0 && "Matrix product with noalias() allocated memory (unexpected)"); +  assert(g_deallocation_count == 0 && "Matrix product with noalias() freed memory (unexpected)"); + +  // Verify result: c should be rows x rows with value 2*cols (each element is sum of cols elements of value 2) +  assert(c.rows() == rows && c.cols() == rows); +  assert(std::abs(c(0, 0) - 2.0f * cols) < 0.001f); +} + +// Test that resizing a matrix causes allocations (should be caught) +void test_allocation_tracking_fail() +{ +  const int rows = 10; +  const int cols = 20; + +  // Pre-allocate matrix +  Eigen::MatrixXf a(rows, cols); +  a.setConstant(1.0f); + +  // Reset
allocation counters + g_allocation_count = 0; + g_deallocation_count = 0; + g_tracking_enabled = true; + + // This operation should allocate (resizing requires reallocation) + a.resize(rows * 2, cols * 2); + + // Disable tracking before any cleanup + g_tracking_enabled = false; + + // Assert that allocations occurred (this test verifies our tracking works) + // Note: This test is meant to verify the tracking mechanism works, + // so we expect allocations/deallocations here + assert((g_allocation_count > 0 || g_deallocation_count > 0) + && "Matrix resize should have caused allocations (tracking may not be working)"); +} + +// Test that Conv1D::Process() method does not allocate or free memory +void test_conv1d_process_realtime_safe() +{ + // Setup: Create a Conv1D + const int in_channels = 1; + const int out_channels = 1; + const int kernel_size = 1; + const bool do_bias = false; + const int dilation = 1; + + nam::Conv1D conv; + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + // Set weights: simple identity + std::vector weights{1.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + + const int maxBufferSize = 256; + conv.SetMaxBufferSize(maxBufferSize); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input matrix (allocate before tracking) + Eigen::MatrixXf input(in_channels, buffer_size); + input.setConstant(0.5f); + + // Reset allocation counters + g_allocation_count = 0; + g_deallocation_count = 0; + g_tracking_enabled = true; + + // Call Process() - this should not allocate or free + conv.Process(input, buffer_size); + + // Disable tracking before any cleanup + g_tracking_enabled = false; + + // Debug output + if (g_allocation_count > 0 || g_deallocation_count > 0) + { + std::cerr << "Conv1D Process - Buffer size " << buffer_size << ": allocations=" << g_allocation_count + << ", deallocations=" << g_deallocation_count << "\n"; + } + + // Assert no allocations or frees occurred + if (g_allocation_count != 0 || g_deallocation_count != 0) + { + std::cerr << "ERROR: Conv1D Process - Buffer size " << buffer_size << " - allocated " << g_allocation_count + << " times, freed " << g_deallocation_count << " times (not real-time safe)\n"; + std::abort(); + } + + // Verify output is valid + auto output = conv.GetOutput().leftCols(buffer_size); + assert(output.rows() == out_channels && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + } +} + +// Test that Layer::Process() method does not allocate or free memory +void test_layer_process_realtime_safe() +{ + // Setup: Create a Layer + const int condition_size = 1; + const int channels = 1; + const int kernel_size = 1; + const int dilation = 1; + const std::string activation = "ReLU"; + const bool gated = false; + + auto layer = nam::wavenet::_Layer(condition_size, channels, kernel_size, dilation, activation, gated); + + // Set weights + std::vector weights{1.0f, 0.0f, // Conv (weight, bias) + 1.0f, // Input mixin + 1.0f, 0.0f}; // 1x1 (weight, bias) + auto it = weights.begin(); + layer.set_weights_(it); + + const int maxBufferSize = 256; + layer.SetMaxBufferSize(maxBufferSize); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf input(channels, buffer_size); + Eigen::MatrixXf condition(condition_size, 
buffer_size); + input.setConstant(0.5f); + condition.setConstant(0.5f); + + // Reset allocation counters + g_allocation_count = 0; + g_deallocation_count = 0; + g_tracking_enabled = true; + + // Call Process() - this should not allocate or free + layer.Process(input, condition, buffer_size); + + // Disable tracking before any cleanup + g_tracking_enabled = false; + + // Debug output + if (g_allocation_count > 0 || g_deallocation_count > 0) + { + std::cerr << "Layer Process - Buffer size " << buffer_size << ": allocations=" << g_allocation_count + << ", deallocations=" << g_deallocation_count << "\n"; + } + + // Assert no allocations or frees occurred + if (g_allocation_count != 0 || g_deallocation_count != 0) + { + std::cerr << "ERROR: Layer Process - Buffer size " << buffer_size << " - allocated " << g_allocation_count + << " times, freed " << g_deallocation_count << " times (not real-time safe)\n"; + std::abort(); + } + + // Verify output is valid + auto output = layer.GetOutputNextLayer().leftCols(buffer_size); + assert(output.rows() == channels && output.cols() == buffer_size); + assert(std::isfinite(output(0, 0))); + } +} + +// Test that LayerArray::Process() method does not allocate or free memory +void test_layer_array_process_realtime_safe() +{ + // Setup: Create LayerArray + const int input_size = 1; + const int condition_size = 1; + const int head_size = 1; + const int channels = 1; + const int kernel_size = 1; + std::vector dilations{1}; + const std::string activation = "ReLU"; + const bool gated = false; + const bool head_bias = false; + + auto layer_array = nam::wavenet::_LayerArray( + input_size, condition_size, head_size, channels, kernel_size, dilations, activation, gated, head_bias); + + // Set weights: rechannel(1), layer(conv:1+1, input_mixin:1, 1x1:1+1), head_rechannel(1) + std::vector weights{1.0f, // Rechannel + 1.0f, 0.0f, // Layer: conv + 1.0f, // Layer: input_mixin + 1.0f, 0.0f, // Layer: 1x1 + 1.0f}; // Head rechannel + auto it = weights.begin(); + layer_array.set_weights_(it); + + const int maxBufferSize = 256; + layer_array.SetMaxBufferSize(maxBufferSize); + + // Test with several different buffer sizes + std::vector buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + + for (int buffer_size : buffer_sizes) + { + // Prepare input/condition matrices (allocate before tracking) + Eigen::MatrixXf layer_inputs(input_size, buffer_size); + Eigen::MatrixXf condition(condition_size, buffer_size); + layer_inputs.setConstant(0.5f); + condition.setConstant(0.5f); + + // Reset allocation counters + g_allocation_count = 0; + g_deallocation_count = 0; + g_tracking_enabled = true; + + // Call Process() - this should not allocate or free + layer_array.Process(layer_inputs, condition, buffer_size); + + // Disable tracking before any cleanup + g_tracking_enabled = false; + + // Debug output + if (g_allocation_count > 0 || g_deallocation_count > 0) + { + std::cerr << "LayerArray Process - Buffer size " << buffer_size << ": allocations=" << g_allocation_count + << ", deallocations=" << g_deallocation_count << "\n"; + } + + // Assert no allocations or frees occurred + if (g_allocation_count != 0 || g_deallocation_count != 0) + { + std::cerr << "ERROR: LayerArray Process - Buffer size " << buffer_size << " - allocated " << g_allocation_count + << " times, freed " << g_deallocation_count << " times (not real-time safe)\n"; + std::abort(); + } + + // Verify output is valid + auto layer_outputs = layer_array.GetLayerOutputs().leftCols(buffer_size); + auto head_outputs = 
layer_array.GetHeadOutputs().leftCols(buffer_size); +    assert(layer_outputs.rows() == channels && layer_outputs.cols() == buffer_size); +    assert(head_outputs.rows() == head_size && head_outputs.cols() == buffer_size); +    assert(std::isfinite(layer_outputs(0, 0))); +    assert(std::isfinite(head_outputs(0, 0))); +  } +} + +// Test that WaveNet::process() method does not allocate or free memory +void test_process_realtime_safe() +{ +  // Setup: Create WaveNet with two layer arrays (simplified configuration) +  const int input_size = 1; +  const int condition_size = 1; +  const int head_size = 1; +  const int channels = 1; +  const int kernel_size = 1; +  std::vector<int> dilations{1}; +  const std::string activation = "ReLU"; +  const bool gated = false; +  const bool head_bias = false; +  const float head_scale = 1.0f; +  const bool with_head = false; + +  std::vector<nam::wavenet::LayerArrayParams> layer_array_params; +  // First layer array +  layer_array_params.emplace_back( +    input_size, condition_size, head_size, channels, kernel_size, std::vector<int>{1}, activation, gated, head_bias); +  // Second layer array (head_size of first must match channels of second) +  layer_array_params.emplace_back( +    head_size, condition_size, head_size, channels, kernel_size, std::vector<int>{1}, activation, gated, head_bias); + +  // Weights: Array 0: rechannel(1), layer(conv:1+1, input_mixin:1, 1x1:1+1), head_rechannel(1) +  //          Array 1: same structure +  //          Head scale: 1 +  std::vector<float> weights; +  // Array 0: rechannel, layer, head_rechannel +  weights.insert(weights.end(), {1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f}); +  // Array 1: rechannel, layer, head_rechannel +  weights.insert(weights.end(), {1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f}); +  weights.push_back(head_scale); + +  auto wavenet = std::make_unique<nam::wavenet::WaveNet>(layer_array_params, head_scale, with_head, weights, 48000.0); + +  const int maxBufferSize = 256; +  wavenet->Reset(48000.0, maxBufferSize); + +  // Test with several different buffer sizes +  std::vector<int> buffer_sizes{1, 8, 16, 32, 64, 128, 256}; + +  for (int buffer_size : buffer_sizes) +  { +    // Prepare input/output buffers (allocate before tracking) +    std::vector<float> input(buffer_size, 0.5f); +    std::vector<float> output(buffer_size, 0.0f); + +    // Reset allocation counters +    g_allocation_count = 0; +    g_deallocation_count = 0; +    g_tracking_enabled = true; + +    // Call process() - this should not allocate or free +    wavenet->process(input.data(), output.data(), buffer_size); + +    // Disable tracking before any cleanup +    g_tracking_enabled = false; + +    // Debug output +    if (g_allocation_count > 0 || g_deallocation_count > 0) +    { +      std::cerr << "Buffer size " << buffer_size << ": allocations=" << g_allocation_count +                << ", deallocations=" << g_deallocation_count << "\n"; +    } + +    // Assert no allocations or frees occurred +    if (g_allocation_count != 0 || g_deallocation_count != 0) +    { +      std::cerr << "ERROR: Buffer size " << buffer_size << " - process() allocated " << g_allocation_count +                << " times, freed " << g_deallocation_count << " times (not real-time safe)\n"; +      std::abort(); +    } + +    // Verify output is valid +    for (int i = 0; i < buffer_size; i++) +    { +      assert(std::isfinite(output[i])); +    } +  } +} +} // namespace test_wavenet
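// A standalone worked version of the arithmetic spelled out in the comments of test_gated
// (in test_layer.cpp above), under the same assumptions as that test: unity conv and 1x1
// weights, (1, -1) input-mixin weights, zero biases, and input == condition == 0.25.
// The function name is illustrative only and is not part of the test suite.
#include <cassert>
#include <cmath>

static void gated_layer_arithmetic_sketch()
{
  const float x = 0.25f, cond = 0.25f;
  const float top = x + cond;                           // conv + input mixin, upper channel -> 0.5
  const float bottom = x - cond;                        // (1, -1) mixin weights cancel on the lower channel -> 0.0
  const float relu = top > 0.0f ? top : 0.0f;           // ReLU(0.5) = 0.5
  const float gate = 1.0f / (1.0f + std::exp(-bottom)); // sigmoid(0) = 0.5
  const float z = relu * gate;                          // gated activation -> 0.25
  const float next_layer = x + z;                       // skip connection -> 0.5 (layer output)
  const float head = 0.0f + z;                          // head input starts at zero -> 0.25
  assert(next_layer == 0.5f);
  assert(head == 0.25f);
}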
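// Sketch of the weight-count bookkeeping described in the comments of test_layer_array_basic
// above (non-gated, no head bias): a rechannel, then per layer a dilated conv with bias, an
// input mixin without bias, and a 1x1 with bias, followed by the head rechannel without bias.
// toy_layer_array_weight_count is a hypothetical helper for illustration, not a library function.
static long toy_layer_array_weight_count(const int input_size, const int condition_size, const int channels,
                                         const int kernel_size, const int n_layers, const int head_size)
{
  long n = (long)channels * input_size; // rechannel (no bias)
  for (int l = 0; l < n_layers; l++)
  {
    n += (long)channels * channels * kernel_size + channels; // dilated conv + bias
    n += (long)channels * condition_size;                    // input mixin (no bias)
    n += (long)channels * channels + channels;               // 1x1 + bias
  }
  n += (long)head_size * channels; // head rechannel (no bias)
  return n;
}
// With the all-ones configuration of test_layer_array_basic (two layers), this gives
// 1 + 2 * 5 + 1 = 12, matching the twelve weights pushed in that test.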
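// Worked check of the receptive-field arithmetic exercised in test_layer_array_receptive_field
// above: a minimal, standalone sketch (not library code) that sums dilation * (kernel_size - 1)
// over a dilation stack. With kernel_size = 3 and dilations {1, 2, 4} it yields 2 + 4 + 8 = 14.
#include <cassert>
#include <vector>

static long toy_receptive_field(const int kernel_size, const std::vector<int>& dilations)
{
  long rf = 0;
  for (const int d : dilations)
    rf += (long)d * (kernel_size - 1); // lookback contributed by one dilated conv
  return rf;
}
// Usage sketch, mirroring the test's configuration:
//   assert(toy_receptive_field(3, {1, 2, 4}) == 14);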
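// Minimal, self-contained illustration of the allocation-counting pattern used in
// test_real_time_safe.cpp: a flag is switched on around the code under test and every
// allocation made while it is set is counted. The real tests above additionally hook
// malloc/free/realloc through dlsym(RTLD_NEXT, ...) so that Eigen's direct malloc calls
// are seen as well; this sketch only replaces operator new/delete and uses illustrative
// names (s_counting, s_news), so it is a simplified stand-in, not the test's mechanism.
#include <cassert>
#include <cstdlib>
#include <new>

static bool s_counting = false;
static int s_news = 0;

void* operator new(std::size_t size)
{
  void* p = std::malloc(size == 0 ? 1 : size);
  if (!p)
    throw std::bad_alloc();
  if (s_counting)
    ++s_news; // count only while the tracked region is active
  return p;
}

void operator delete(void* p) noexcept
{
  std::free(p);
}

static void realtime_section_must_not_allocate()
{
  s_news = 0;
  s_counting = true;
  // ... code under test goes here; with pre-allocated buffers it should not allocate ...
  s_counting = false;
  assert(s_news == 0);
}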