diff --git a/NAM/activations.cpp b/NAM/activations.cpp index 828cea8..b01060b 100644 --- a/NAM/activations.cpp +++ b/NAM/activations.cpp @@ -7,39 +7,33 @@ nam::activations::ActivationReLU _RELU = nam::activations::ActivationReLU(); nam::activations::ActivationLeakyReLU _LEAKY_RELU = nam::activations::ActivationLeakyReLU(); nam::activations::ActivationSigmoid _SIGMOID = nam::activations::ActivationSigmoid(); -bool nam::activations::Activation::using_fast_tanh = false; +std::atomic nam::activations::Activation::using_fast_tanh{false}; std::unordered_map nam::activations::Activation::_activations = { {"Tanh", &_TANH}, {"Hardtanh", &_HARD_TANH}, {"Fasttanh", &_FAST_TANH}, {"ReLU", &_RELU}, {"LeakyReLU", &_LEAKY_RELU}, {"Sigmoid", &_SIGMOID}}; -nam::activations::Activation* tanh_bak = nullptr; - nam::activations::Activation* nam::activations::Activation::get_activation(const std::string name) { - if (_activations.find(name) == _activations.end()) + // Return FastTanh when Tanh is requested and fast_tanh mode is enabled + if (name == "Tanh" && using_fast_tanh.load(std::memory_order_relaxed)) + { + return _activations.at("Fasttanh"); + } + + auto it = _activations.find(name); + if (it == _activations.end()) return nullptr; - return _activations[name]; + return it->second; } void nam::activations::Activation::enable_fast_tanh() { - nam::activations::Activation::using_fast_tanh = true; - - if (_activations["Tanh"] != _activations["Fasttanh"]) - { - tanh_bak = _activations["Tanh"]; - _activations["Tanh"] = _activations["Fasttanh"]; - } + using_fast_tanh.store(true, std::memory_order_relaxed); } void nam::activations::Activation::disable_fast_tanh() { - nam::activations::Activation::using_fast_tanh = false; - - if (_activations["Tanh"] == _activations["Fasttanh"]) - { - _activations["Tanh"] = tanh_bak; - } + using_fast_tanh.store(false, std::memory_order_relaxed); } diff --git a/NAM/activations.h b/NAM/activations.h index fe47203..afa28fe 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include // expf +#include #include #include @@ -63,7 +64,7 @@ class Activation static Activation* get_activation(const std::string name); static void enable_fast_tanh(); static void disable_fast_tanh(); - static bool using_fast_tanh; + static std::atomic using_fast_tanh; protected: static std::unordered_map _activations;