diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 72e8935923..b8c30d97db 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -242,6 +242,14 @@ be cheap to copy. When returning shared pointers, they should be returned by value rather than by reference. Methods should be marked as `const` whenever they do not modify the object's state. +#### Thread Safety + +MaterialX classes support multiple concurrent readers, but not concurrent +reads and writes, following the pattern of standard C++ containers. This +design enables efficient parallel processing in read-heavy workloads such +as shader generation and scene traversal, while keeping the implementation +simple and avoiding the overhead of fine-grained locking. + #### Exception Handling Exceptions should be used for exceptional conditions rather than for normal diff --git a/source/MaterialXCore/Document.cpp b/source/MaterialXCore/Document.cpp index 3bb8a5b106..ec8bb54e84 100644 --- a/source/MaterialXCore/Document.cpp +++ b/source/MaterialXCore/Document.cpp @@ -6,6 +6,7 @@ #include #include +#include MATERIALX_NAMESPACE_BEGIN @@ -29,82 +30,128 @@ class Document::Cache { public: Cache() : - valid(false) + _valid(false) { } ~Cache() = default; - void refresh() + void setDocument(weak_ptr document) { - // Thread synchronization for multiple concurrent readers of a single document. - std::lock_guard guard(mutex); + std::unique_lock lock(_mutex); + _doc = document; + _valid = false; + } + + void invalidate() + { + std::unique_lock lock(_mutex); + _valid = false; + } + + vector getMatchingPorts(const string& nodeName) + { + auto lock = refreshWithLock(); + auto it = _portElementMap.find(nodeName); + return (it != _portElementMap.end()) ? it->second : vector(); + } + + vector getMatchingNodeDefs(const string& nodeName) + { + auto lock = refreshWithLock(); + auto it = _nodeDefMap.find(nodeName); + return (it != _nodeDefMap.end()) ? it->second : vector(); + } + + vector getMatchingImplementations(const string& nodeDef) + { + auto lock = refreshWithLock(); + auto it = _implementationMap.find(nodeDef); + return (it != _implementationMap.end()) ? it->second : vector(); + } - if (!valid) + private: + std::shared_lock refreshWithLock() + { + std::shared_lock lock(_mutex); + + if (_valid) { - // Clear the existing cache. - portElementMap.clear(); - nodeDefMap.clear(); - implementationMap.clear(); + return lock; + } - // Traverse the document to build a new cache. - for (ElementPtr elem : doc.lock()->traverseTree()) - { - const string& nodeName = elem->getAttribute(PortElement::NODE_NAME_ATTRIBUTE); - const string& nodeGraphName = elem->getAttribute(PortElement::NODE_GRAPH_ATTRIBUTE); - const string& nodeString = elem->getAttribute(NodeDef::NODE_ATTRIBUTE); - const string& nodeDefString = elem->getAttribute(InterfaceElement::NODE_DEF_ATTRIBUTE); + lock.unlock(); - if (!nodeName.empty()) + { + std::unique_lock writeLock(_mutex); + if (!_valid) + { + auto doc = _doc.lock(); + if (doc) { - PortElementPtr portElem = elem->asA(); - if (portElem) - { - portElementMap[portElem->getQualifiedName(nodeName)].push_back(portElem); - } + rebuild(doc); } - else + } + } + + lock.lock(); + return lock; + } + + void rebuild(DocumentPtr doc) + { + // Clear the existing cache. + _portElementMap.clear(); + _nodeDefMap.clear(); + _implementationMap.clear(); + + // Traverse the document to build a new cache. + for (ElementPtr elem : doc->traverseTree()) + { + const string& nodeName = elem->getAttribute(PortElement::NODE_NAME_ATTRIBUTE); + const string& nodeGraphName = elem->getAttribute(PortElement::NODE_GRAPH_ATTRIBUTE); + const string& nodeString = elem->getAttribute(NodeDef::NODE_ATTRIBUTE); + const string& nodeDefString = elem->getAttribute(InterfaceElement::NODE_DEF_ATTRIBUTE); + + const string& portKey = !nodeName.empty() ? nodeName : nodeGraphName; + if (!portKey.empty()) + { + PortElementPtr portElem = elem->asA(); + if (portElem) { - if (!nodeGraphName.empty()) - { - PortElementPtr portElem = elem->asA(); - if (portElem) - { - portElementMap[portElem->getQualifiedName(nodeGraphName)].push_back(portElem); - } - } + _portElementMap[portElem->getQualifiedName(portKey)].push_back(portElem); } - if (!nodeString.empty()) + } + if (!nodeString.empty()) + { + NodeDefPtr nodeDef = elem->asA(); + if (nodeDef) { - NodeDefPtr nodeDef = elem->asA(); - if (nodeDef) - { - nodeDefMap[nodeDef->getQualifiedName(nodeString)].push_back(nodeDef); - } + _nodeDefMap[nodeDef->getQualifiedName(nodeString)].push_back(nodeDef); } - if (!nodeDefString.empty()) + } + if (!nodeDefString.empty()) + { + InterfaceElementPtr interface = elem->asA(); + if (interface) { - InterfaceElementPtr interface = elem->asA(); - if (interface) + if (interface->isA() || interface->isA()) { - if (interface->isA() || interface->isA()) - { - implementationMap[interface->getQualifiedName(nodeDefString)].push_back(interface); - } + _implementationMap[interface->getQualifiedName(nodeDefString)].push_back(interface); } } } - - valid = true; } + + _valid = true; } - public: - weak_ptr doc; - std::mutex mutex; - bool valid; - std::unordered_map> portElementMap; - std::unordered_map> nodeDefMap; - std::unordered_map> implementationMap; + private: + weak_ptr _doc; + mutable std::shared_mutex _mutex; + bool _valid; + std::unordered_map> _portElementMap; + std::unordered_map> _nodeDefMap; + std::unordered_map> _implementationMap; }; // @@ -124,7 +171,7 @@ Document::~Document() void Document::initialize() { _root = getSelf(); - _cache->doc = getDocument(); + _cache->setDocument(getDocument()); clearContent(); setVersionIntegers(MATERIALX_MAJOR_VERSION, MATERIALX_MINOR_VERSION); @@ -284,18 +331,7 @@ std::pair Document::getVersionIntegers() const vector Document::getMatchingPorts(const string& nodeName) const { - // Refresh the cache. - _cache->refresh(); - - // Return all port elements matching the given node name. - if (_cache->portElementMap.count(nodeName)) - { - return _cache->portElementMap.at(nodeName); - } - else - { - return vector(); - } + return _cache->getMatchingPorts(nodeName); } ValuePtr Document::getGeomPropValue(const string& geomPropName, const string& geom) const @@ -342,19 +378,14 @@ vector Document::getMaterialOutputs() const vector Document::getMatchingNodeDefs(const string& nodeName) const { // Recurse to data library if present. - vector matchingNodeDefs = hasDataLibrary() ? + vector matchingNodeDefs = hasDataLibrary() ? getDataLibrary()->getMatchingNodeDefs(nodeName) : vector(); - // Refresh the cache. - _cache->refresh(); + // Append all nodedefs matching the given node name. + vector localNodeDefs = _cache->getMatchingNodeDefs(nodeName); + matchingNodeDefs.insert(matchingNodeDefs.end(), localNodeDefs.begin(), localNodeDefs.end()); - // Return all nodedefs matching the given node name. - if (_cache->nodeDefMap.count(nodeName)) - { - matchingNodeDefs.insert(matchingNodeDefs.end(), _cache->nodeDefMap.at(nodeName).begin(), _cache->nodeDefMap.at(nodeName).end()); - } - return matchingNodeDefs; } @@ -364,15 +395,10 @@ vector Document::getMatchingImplementations(const string& n vector matchingImplementations = hasDataLibrary() ? getDataLibrary()->getMatchingImplementations(nodeDef) : vector(); - - // Refresh the cache. - _cache->refresh(); - // Return all implementations matching the given nodedef string. - if (_cache->implementationMap.count(nodeDef)) - { - matchingImplementations.insert(matchingImplementations.end(), _cache->implementationMap.at(nodeDef).begin(), _cache->implementationMap.at(nodeDef).end()); - } + // Append all implementations matching the given nodedef string. + vector localImpls = _cache->getMatchingImplementations(nodeDef); + matchingImplementations.insert(matchingImplementations.end(), localImpls.begin(), localImpls.end()); return matchingImplementations; } @@ -388,7 +414,7 @@ bool Document::validate(string* message) const void Document::invalidateCache() { - _cache->valid = false; + _cache->invalidate(); } //