// // Copyright Contributors to the MaterialX Project // SPDX-License-Identifier: Apache-2.0 // #include #include #include #include namespace mx = MaterialX; TEST_CASE("IntraGraph Traversal", "[traversal]") { // Test null iterators. mx::TreeIterator nullTree = mx::getNullTreeIterator(); mx::GraphIterator nullGraph = mx::getNullGraphIterator(); REQUIRE(*nullTree == nullptr); REQUIRE(*nullGraph == mx::getNullEdge()); ++nullTree; ++nullGraph; REQUIRE((nullTree == mx::getNullTreeIterator())); REQUIRE((nullGraph == mx::getNullGraphIterator())); // Create a document. mx::DocumentPtr doc = mx::createDocument(); // Create a node graph with the following structure: // // [image1] [constant] [image2] // \ / | // [multiply] [contrast] [noise3d] // \____________ | ____________/ // [mix] // | // [output] // mx::NodeGraphPtr nodeGraph = doc->addNodeGraph(); mx::NodePtr image1 = nodeGraph->addNode("image"); mx::NodePtr image2 = nodeGraph->addNode("image"); mx::NodePtr constant = nodeGraph->addNode("constant"); mx::NodePtr multiply = nodeGraph->addNode("multiply"); mx::NodePtr contrast = nodeGraph->addNode("contrast"); mx::NodePtr noise3d = nodeGraph->addNode("noise3d"); mx::NodePtr mix = nodeGraph->addNode("mix"); mx::OutputPtr output = nodeGraph->addOutput(); multiply->setConnectedNode("in1", image1); multiply->setConnectedNode("in2", constant); contrast->setConnectedNode("in", image2); mix->setConnectedNode("fg", multiply); mix->setConnectedNode("bg", contrast); mix->setConnectedNode("mask", noise3d); output->setConnectedNode(mix); // Validate the document. REQUIRE(doc->validate()); // Traverse the document tree (implicit iterator). int nodeCount = 0; for (mx::ElementPtr elem : doc->traverseTree()) { REQUIRE(elem->getName() == mx::createValidName(elem->getName())); if (elem->isA()) { nodeCount++; } } REQUIRE(nodeCount == 7); // Traverse the document tree (explicit iterator). nodeCount = 0; size_t maxElementDepth = 0; for (mx::TreeIterator it = doc->traverseTree().begin(); it != mx::TreeIterator::end(); ++it) { mx::ElementPtr elem = it.getElement(); if (elem->isA()) { nodeCount++; } maxElementDepth = std::max(maxElementDepth, it.getElementDepth()); } REQUIRE(nodeCount == 7); REQUIRE(maxElementDepth == 3); // Traverse the document tree (prune subtree). nodeCount = 0; for (mx::TreeIterator it = doc->traverseTree().begin(); it != mx::TreeIterator::end(); ++it) { mx::ElementPtr elem = it.getElement(); if (elem->isA()) { nodeCount++; } if (elem->isA()) { it.setPruneSubtree(true); } } REQUIRE(nodeCount == 0); // Traverse upstream from the graph output (implicit iterator). nodeCount = 0; for (mx::Edge edge : output->traverseGraph()) { mx::ElementPtr upstreamElem = edge.getUpstreamElement(); mx::ElementPtr connectingElem = edge.getConnectingElement(); mx::ElementPtr downstreamElem = edge.getDownstreamElement(); if (upstreamElem->isA()) { nodeCount++; if (downstreamElem->isA()) { REQUIRE(connectingElem->isA()); } } } REQUIRE(nodeCount == 7); // Traverse upstream from the graph output (explicit iterator). nodeCount = 0; maxElementDepth = 0; size_t maxNodeDepth = 0; for (mx::GraphIterator it = output->traverseGraph().begin(); it != mx::GraphIterator::end(); ++it) { mx::ElementPtr upstreamElem = it.getUpstreamElement(); mx::ElementPtr connectingElem = it.getConnectingElement(); mx::ElementPtr downstreamElem = it.getDownstreamElement(); if (upstreamElem->isA()) { nodeCount++; if (downstreamElem->isA()) { REQUIRE(connectingElem->isA()); } } maxElementDepth = std::max(maxElementDepth, it.getElementDepth()); maxNodeDepth = std::max(maxNodeDepth, it.getNodeDepth()); } REQUIRE(nodeCount == 7); REQUIRE(maxElementDepth == 3); REQUIRE(maxNodeDepth == 3); // Traverse upstream from the graph output (prune subgraph). nodeCount = 0; for (mx::GraphIterator it = output->traverseGraph().begin(); it != mx::GraphIterator::end(); ++it) { mx::ElementPtr upstreamElem = it.getUpstreamElement(); if (upstreamElem->isA()) { nodeCount++; if (upstreamElem->getCategory() == "multiply") { it.setPruneSubgraph(true); } } } REQUIRE(nodeCount == 5); // Create and detect a cycle. multiply->setConnectedNode("in2", mix); REQUIRE(output->hasUpstreamCycle()); REQUIRE(!doc->validate()); multiply->setConnectedNode("in2", constant); REQUIRE(!output->hasUpstreamCycle()); REQUIRE(doc->validate()); // Create and detect a loop. contrast->setConnectedNode("in", contrast); REQUIRE(output->hasUpstreamCycle()); REQUIRE(!doc->validate()); contrast->setConnectedNode("in", image2); REQUIRE(!output->hasUpstreamCycle()); REQUIRE(doc->validate()); } TEST_CASE("InterGraph Traversal", "[traversal]") { mx::FileSearchPath searchPath = mx::getDefaultDataSearchPath(); mx::DocumentPtr doc = mx::createDocument(); mx::readFromXmlFile(doc, "resources/Materials/TestSuite/stdlib/nodegraphs/nodegraph_nodegraph.mtlx", searchPath); for (mx::NodeGraphPtr graph : doc->getNodeGraphs()) { for (mx::InputPtr interfaceInput : graph->getInputs()) { if (!interfaceInput->getNodeName().empty() || !interfaceInput->getNodeGraphString().empty()) { REQUIRE(interfaceInput->getConnectedNode() != nullptr); } } } }