Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ be cheap to copy. When returning shared pointers, they should be returned by
value rather than by reference. Methods should be marked as `const` whenever
they do not modify the object's state.

#### Thread Safety

MaterialX classes support multiple concurrent readers, but not concurrent
reads and writes, following the pattern of standard C++ containers. This
design enables efficient parallel processing in read-heavy workloads such
as shader generation and scene traversal, while keeping the implementation
simple and avoiding the overhead of fine-grained locking.

#### Exception Handling

Exceptions should be used for exceptional conditions rather than for normal
Expand Down
194 changes: 110 additions & 84 deletions source/MaterialXCore/Document.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <MaterialXCore/Document.h>

#include <mutex>
#include <shared_mutex>

MATERIALX_NAMESPACE_BEGIN

Expand All @@ -29,82 +30,128 @@ class Document::Cache
{
public:
Cache() :
valid(false)
_valid(false)
{
}
~Cache() = default;

void refresh()
void setDocument(weak_ptr<Document> document)
{
// Thread synchronization for multiple concurrent readers of a single document.
std::lock_guard<std::mutex> guard(mutex);
std::unique_lock<std::shared_mutex> lock(_mutex);
_doc = document;
_valid = false;
}

void invalidate()
{
std::unique_lock<std::shared_mutex> lock(_mutex);
_valid = false;
}

vector<PortElementPtr> getMatchingPorts(const string& nodeName)
{
auto lock = refreshWithLock();
auto it = _portElementMap.find(nodeName);
return (it != _portElementMap.end()) ? it->second : vector<PortElementPtr>();
}

vector<NodeDefPtr> getMatchingNodeDefs(const string& nodeName)
{
auto lock = refreshWithLock();
auto it = _nodeDefMap.find(nodeName);
return (it != _nodeDefMap.end()) ? it->second : vector<NodeDefPtr>();
}

vector<InterfaceElementPtr> getMatchingImplementations(const string& nodeDef)
{
auto lock = refreshWithLock();
auto it = _implementationMap.find(nodeDef);
return (it != _implementationMap.end()) ? it->second : vector<InterfaceElementPtr>();
}

if (!valid)
private:
std::shared_lock<std::shared_mutex> refreshWithLock()
{
std::shared_lock<std::shared_mutex> lock(_mutex);

if (_valid)
{
// Clear the existing cache.
portElementMap.clear();
nodeDefMap.clear();
implementationMap.clear();
return lock;
}

// Traverse the document to build a new cache.
for (ElementPtr elem : doc.lock()->traverseTree())
{
const string& nodeName = elem->getAttribute(PortElement::NODE_NAME_ATTRIBUTE);
const string& nodeGraphName = elem->getAttribute(PortElement::NODE_GRAPH_ATTRIBUTE);
const string& nodeString = elem->getAttribute(NodeDef::NODE_ATTRIBUTE);
const string& nodeDefString = elem->getAttribute(InterfaceElement::NODE_DEF_ATTRIBUTE);
lock.unlock();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That is the best currently possible with C++ (up to 23). Rechecking the _valid condition after re-acquiring the lock makes sure we handle the case where another thread pre-emptively rebuilds in the small window after the shared lock is released, especially since this thread now has to wait for all shared locks to be released before proceeding.

LGTM!


if (!nodeName.empty())
{
std::unique_lock<std::shared_mutex> writeLock(_mutex);
if (!_valid)
{
auto doc = _doc.lock();
if (doc)
{
PortElementPtr portElem = elem->asA<PortElement>();
if (portElem)
{
portElementMap[portElem->getQualifiedName(nodeName)].push_back(portElem);
}
rebuild(doc);
}
else
}
}

lock.lock();
return lock;
}

void rebuild(DocumentPtr doc)
{
// Clear the existing cache.
_portElementMap.clear();
_nodeDefMap.clear();
_implementationMap.clear();

// Traverse the document to build a new cache.
for (ElementPtr elem : doc->traverseTree())
{
const string& nodeName = elem->getAttribute(PortElement::NODE_NAME_ATTRIBUTE);
const string& nodeGraphName = elem->getAttribute(PortElement::NODE_GRAPH_ATTRIBUTE);
const string& nodeString = elem->getAttribute(NodeDef::NODE_ATTRIBUTE);
const string& nodeDefString = elem->getAttribute(InterfaceElement::NODE_DEF_ATTRIBUTE);

const string& portKey = !nodeName.empty() ? nodeName : nodeGraphName;
if (!portKey.empty())
{
PortElementPtr portElem = elem->asA<PortElement>();
if (portElem)
{
if (!nodeGraphName.empty())
{
PortElementPtr portElem = elem->asA<PortElement>();
if (portElem)
{
portElementMap[portElem->getQualifiedName(nodeGraphName)].push_back(portElem);
}
}
_portElementMap[portElem->getQualifiedName(portKey)].push_back(portElem);
}
if (!nodeString.empty())
}
if (!nodeString.empty())
{
NodeDefPtr nodeDef = elem->asA<NodeDef>();
if (nodeDef)
{
NodeDefPtr nodeDef = elem->asA<NodeDef>();
if (nodeDef)
{
nodeDefMap[nodeDef->getQualifiedName(nodeString)].push_back(nodeDef);
}
_nodeDefMap[nodeDef->getQualifiedName(nodeString)].push_back(nodeDef);
}
if (!nodeDefString.empty())
}
if (!nodeDefString.empty())
{
InterfaceElementPtr interface = elem->asA<InterfaceElement>();
if (interface)
{
InterfaceElementPtr interface = elem->asA<InterfaceElement>();
if (interface)
if (interface->isA<Implementation>() || interface->isA<NodeGraph>())
{
if (interface->isA<Implementation>() || interface->isA<NodeGraph>())
{
implementationMap[interface->getQualifiedName(nodeDefString)].push_back(interface);
}
_implementationMap[interface->getQualifiedName(nodeDefString)].push_back(interface);
}
}
}

valid = true;
}

_valid = true;
}

public:
weak_ptr<Document> doc;
std::mutex mutex;
bool valid;
std::unordered_map<string, std::vector<PortElementPtr>> portElementMap;
std::unordered_map<string, std::vector<NodeDefPtr>> nodeDefMap;
std::unordered_map<string, std::vector<InterfaceElementPtr>> implementationMap;
private:
weak_ptr<Document> _doc;
mutable std::shared_mutex _mutex;
bool _valid;
std::unordered_map<string, std::vector<PortElementPtr>> _portElementMap;
std::unordered_map<string, std::vector<NodeDefPtr>> _nodeDefMap;
std::unordered_map<string, std::vector<InterfaceElementPtr>> _implementationMap;
};

//
Expand All @@ -124,7 +171,7 @@ Document::~Document()
void Document::initialize()
{
_root = getSelf();
_cache->doc = getDocument();
_cache->setDocument(getDocument());

clearContent();
setVersionIntegers(MATERIALX_MAJOR_VERSION, MATERIALX_MINOR_VERSION);
Expand Down Expand Up @@ -284,18 +331,7 @@ std::pair<int, int> Document::getVersionIntegers() const

vector<PortElementPtr> Document::getMatchingPorts(const string& nodeName) const
{
// Refresh the cache.
_cache->refresh();

// Return all port elements matching the given node name.
if (_cache->portElementMap.count(nodeName))
{
return _cache->portElementMap.at(nodeName);
}
else
{
return vector<PortElementPtr>();
}
return _cache->getMatchingPorts(nodeName);
}

ValuePtr Document::getGeomPropValue(const string& geomPropName, const string& geom) const
Expand Down Expand Up @@ -342,19 +378,14 @@ vector<OutputPtr> Document::getMaterialOutputs() const
vector<NodeDefPtr> Document::getMatchingNodeDefs(const string& nodeName) const
{
// Recurse to data library if present.
vector<NodeDefPtr> matchingNodeDefs = hasDataLibrary() ?
vector<NodeDefPtr> matchingNodeDefs = hasDataLibrary() ?
getDataLibrary()->getMatchingNodeDefs(nodeName) :
vector<NodeDefPtr>();

// Refresh the cache.
_cache->refresh();
// Append all nodedefs matching the given node name.
vector<NodeDefPtr> localNodeDefs = _cache->getMatchingNodeDefs(nodeName);
matchingNodeDefs.insert(matchingNodeDefs.end(), localNodeDefs.begin(), localNodeDefs.end());

// Return all nodedefs matching the given node name.
if (_cache->nodeDefMap.count(nodeName))
{
matchingNodeDefs.insert(matchingNodeDefs.end(), _cache->nodeDefMap.at(nodeName).begin(), _cache->nodeDefMap.at(nodeName).end());
}

return matchingNodeDefs;
}

Expand All @@ -364,15 +395,10 @@ vector<InterfaceElementPtr> Document::getMatchingImplementations(const string& n
vector<InterfaceElementPtr> matchingImplementations = hasDataLibrary() ?
getDataLibrary()->getMatchingImplementations(nodeDef) :
vector<InterfaceElementPtr>();

// Refresh the cache.
_cache->refresh();

// Return all implementations matching the given nodedef string.
if (_cache->implementationMap.count(nodeDef))
{
matchingImplementations.insert(matchingImplementations.end(), _cache->implementationMap.at(nodeDef).begin(), _cache->implementationMap.at(nodeDef).end());
}
// Append all implementations matching the given nodedef string.
vector<InterfaceElementPtr> localImpls = _cache->getMatchingImplementations(nodeDef);
matchingImplementations.insert(matchingImplementations.end(), localImpls.begin(), localImpls.end());

return matchingImplementations;
}
Expand All @@ -388,7 +414,7 @@ bool Document::validate(string* message) const

void Document::invalidateCache()
{
_cache->valid = false;
_cache->invalidate();
}

//
Expand Down