This commit is contained in:
2026-01-06 13:25:49 +00:00
parent 5d495d731b
commit 4e15e08b7f
1395 changed files with 295666 additions and 323 deletions

View File

@@ -0,0 +1,18 @@
if(NOT SKBUILD)
install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/MaterialX" DESTINATION "python" MESSAGE_NEVER)
install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/Scripts" DESTINATION "python" MESSAGE_NEVER)
endif()
if(SKBUILD)
install(
DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/Scripts/"
DESTINATION "${MATERIALX_PYTHON_FOLDER_NAME}/_scripts"
PATTERN "README.md" EXCLUDE
)
endif()
if(MATERIALX_INSTALL_PYTHON AND PYTHON_EXECUTABLE AND NOT SKBUILD)
set(SETUP_PY "${CMAKE_INSTALL_PREFIX}/python/setup.py")
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in" "${SETUP_PY}")
install(CODE "execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install . WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX}/python)")
endif()

View File

@@ -0,0 +1 @@
recursive-include libraries

View File

@@ -0,0 +1,24 @@
# Python 3.8+ on Windows: DLL search paths for dependent
# shared libraries
# Refs.:
# - https://github.com/python/cpython/issues/80266
# - https://docs.python.org/3.8/library/os.html#os.add_dll_directory
import os
import sys
if sys.platform == "win32" and sys.version_info >= (3, 8):
import importlib.metadata
try:
importlib.metadata.version('MaterialX')
except importlib.metadata.PackageNotFoundError:
# On a non-pip installation, this file is in %INSTALLDIR%\python\MaterialX
# We need to add %INSTALLDIR%\bin to the DLL path.
mxdir = os.path.dirname(__file__)
pydir = os.path.split(mxdir)[0]
installdir = os.path.split(pydir)[0]
bindir = os.path.join(installdir, "bin")
if os.path.exists(bindir):
os.add_dll_directory(bindir)
from .main import *
__version__ = getVersionString()

View File

@@ -0,0 +1,2 @@
This directory is empty built it's used when packaging the Python library.
the files in ../../Scripts will be copied inside.

View File

@@ -0,0 +1 @@
# Only required for entry-points.

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python
'''
Native Python helper functions for MaterialX data types.
'''
import sys
from .PyMaterialXCore import *
#--------------------------------------------------------------------------------
_typeToName = { int : 'integer',
float : 'float',
bool : 'boolean',
Color3 : 'color3',
Color4 : 'color4',
Vector2 : 'vector2',
Vector3 : 'vector3',
Vector4 : 'vector4',
Matrix33 : 'matrix33',
Matrix44 : 'matrix44',
str : 'string' }
if sys.version_info[0] < 3:
_typeToName[long] = 'integer'
_typeToName[unicode] = 'string'
else:
_typeToName[bytes] = 'string'
#--------------------------------------------------------------------------------
def getTypeString(value):
"""Return the MaterialX type string associated with the given Python value
If the type of the given Python value is not recognized by MaterialX,
then None is returned.
Examples:
getTypeString(1.0) -> 'float'
getTypeString(mx.Color3(1)) -> 'color3'"""
valueType = type(value)
if valueType in _typeToName:
return _typeToName[valueType]
if valueType in (tuple, list):
if len(value):
elemType = type(value[0])
if elemType in _typeToName:
return _typeToName[elemType] + 'array'
return 'stringarray'
return None
def getValueString(value):
"""Return the MaterialX value string associated with the given Python value
If the type of the given Python value is not recognized by MaterialX,
then None is returned
Examples:
getValueString(0.1) -> '0.1'
getValueString(mx.Color3(0.1, 0.2, 0.3)) -> '0.1, 0.2, 0.3'"""
typeString = getTypeString(value)
if not typeString:
return None
method = globals()['TypedValue_' + typeString].createValue
return method(value).getValueString()
def createValueFromStrings(valueString, typeString):
"""Convert a MaterialX value and type strings to the corresponding
Python value. If the given conversion cannot be performed, then None
is returned.
Examples:
createValueFromStrings('0.1', 'float') -> 0.1
createValueFromStrings('0.1, 0.2, 0.3', 'color3') -> mx.Color3(0.1, 0.2, 0.3)"""
valueObj = Value.createValueFromStrings(valueString, typeString)
if not valueObj:
return None
return valueObj.getData()
def isColorType(t):
"Return True if the given type is a MaterialX color."
return t in (Color3, Color4)
def isColorValue(value):
"Return True if the given value is a MaterialX color."
return isColorType(type(value))
def stringToBoolean(value):
"Return boolean value found in a string. Throws and exception if a boolean value could not be parsed"
if isinstance(value, bool):
return value
if value.lower() in ('yes', 'true', 't', '1'):
return True
elif value.lower() in ('no', 'false', 'f', '0'):
return False
raise TypeError('Boolean value expected.')

View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python
'''
Native Python wrappers for PyMaterialX, providing a more Pythonic interface
for Elements and Values.
'''
import warnings
from .PyMaterialXCore import *
from .PyMaterialXFormat import *
from .datatype import *
import os
#
# Element
#
def _isA(self, elementClass, category = ''):
"""Return True if this element is an instance of the given subclass.
If a category string is specified, then both subclass and category
matches are required."""
if not isinstance(self, elementClass):
return False
if category and self.getCategory() != category:
return False
return True
def _addChild(self, elementClass, name, typeString = ''):
"Add a child element of the given subclass, name, and optional type string."
method = getattr(self.__class__, "_addChild" + elementClass.__name__)
return method(self, name, typeString)
def _getChild(self, name):
"Return the child element, if any, with the given name."
if (name == None):
return None
return self._getChild(name)
def _getChildOfType(self, elementClass, name):
"Return the child element, if any, with the given name and subclass."
method = getattr(self.__class__, "_getChildOfType" + elementClass.__name__)
return method(self, name)
def _getChildrenOfType(self, elementClass):
"""Return a list of all child elements that are instances of the given type.
The returned list maintains the order in which children were added."""
method = getattr(self.__class__, "_getChildrenOfType" + elementClass.__name__)
return method(self)
def _removeChildOfType(self, elementClass, name):
"Remove the typed child element, if any, with the given name."
method = getattr(self.__class__, "_removeChildOfType" + elementClass.__name__)
method(self, name)
Element.isA = _isA
Element.addChild = _addChild
Element.getChild = _getChild
Element.getChildOfType = _getChildOfType
Element.getChildrenOfType = _getChildrenOfType
Element.removeChildOfType = _removeChildOfType
#
# ValueElement
#
def _setValue(self, value, typeString = ''):
"Set the typed value of an element."
method = getattr(self.__class__, "_setValue" + getTypeString(value))
method(self, value, typeString)
def _getValue(self):
"Return the typed value of an element."
value = self._getValue()
return value.getData() if value else None
def _getDefaultValue(self):
"""Return the default value for this element."""
value = self._getDefaultValue()
return value.getData() if value else None
ValueElement.setValue = _setValue
ValueElement.getValue = _getValue
ValueElement.getDefaultValue = _getDefaultValue
#
# InterfaceElement
#
def _setInputValue(self, name, value, typeString = ''):
"""Set the typed value of an input by its name, creating a child element
to hold the input if needed."""
method = getattr(self.__class__, "_setInputValue" + getTypeString(value))
return method(self, name, value, typeString)
def _getInputValue(self, name, target = ''):
"""Return the typed value of an input by its name, taking both the
calling element and its declaration into account. If the given
input is not found, then None is returned."""
value = self._getInputValue(name, target)
return value.getData() if value else None
def _addParameter(self, name):
"""(Deprecated) Add a Parameter to this interface."""
warnings.warn("This function is deprecated; parameters have been replaced with uniform inputs in 1.38.", DeprecationWarning, stacklevel = 2)
return self.addInput(name)
def _getParameters(self):
"""(Deprecated) Return a vector of all Parameter elements."""
warnings.warn("This function is deprecated; parameters have been replaced with uniform inputs in 1.38.", DeprecationWarning, stacklevel = 2)
return list()
def _getActiveParameters(self):
"""(Deprecated) Return a vector of all parameters belonging to this interface, taking inheritance into account."""
warnings.warn("This function is deprecated; parameters have been replaced with uniform inputs in 1.38.", DeprecationWarning, stacklevel = 2)
return list()
def _setParameterValue(self, name, value, typeString = ''):
"""(Deprecated) Set the typed value of a parameter by its name."""
warnings.warn("This function is deprecated; parameters have been replaced with uniform inputs in 1.38.", DeprecationWarning, stacklevel = 2)
def _getParameterValue(self, name, target = ''):
"""(Deprecated) Return the typed value of a parameter by its name."""
warnings.warn("This function is deprecated; parameters have been replaced with uniform inputs in 1.38.", DeprecationWarning, stacklevel = 2)
return None
def _getParameterValueString(self, name):
"""(Deprecated) Return the value string of a parameter by its name."""
warnings.warn("This function is deprecated; parameters have been replaced with uniform inputs in 1.38.", DeprecationWarning, stacklevel = 2)
return ""
def _addBindInput(self, name, type = DEFAULT_TYPE_STRING):
"""(Deprecated) Add a BindInput to this shader reference."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return self.addInput(name, type)
def _getBindInputs(self):
"""(Deprecated) Return a vector of all BindInput elements in this shader reference."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return self.getInputs()
def _addBindParam(self, name, type = DEFAULT_TYPE_STRING):
"""(Deprecated) Add a BindParam to this shader reference."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return self.addInput(name, type)
def _getBindParams(self):
"""(Deprecated) Return a vector of all BindParam elements in this shader reference."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return list()
def _getBindTokens(self):
"""(Deprecated) Return a vector of all BindToken elements in this shader reference."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return list()
InterfaceElement.setInputValue = _setInputValue
InterfaceElement.getInputValue = _getInputValue
InterfaceElement.addParameter = _addParameter
InterfaceElement.getParameters = _getParameters
InterfaceElement.getActiveParameters = _getActiveParameters
InterfaceElement.setParameterValue = _setParameterValue
InterfaceElement.getParameterValue = _getParameterValue
InterfaceElement.getParameterValueString = _getParameterValueString
InterfaceElement.addBindInput = _addBindInput
InterfaceElement.getBindInputs = _getBindInputs
InterfaceElement.addBindParam = _addBindParam
InterfaceElement.getBindParams = _getBindParams
InterfaceElement.getBindTokens = _getBindTokens
#
# Node
#
def _getReferencedNodeDef(self):
"(Deprecated) Return the first NodeDef that declares this node."
warnings.warn("This function is deprecated; call Node.getNodeDef instead.", DeprecationWarning, stacklevel = 2)
return self.getNodeDef()
def _addShaderRef(self, name, nodeName):
"(Deprecated) Add a shader reference to this material element."
warnings.warn("This function is deprecated; material elements have been replaced with material nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return self.getParent().addNode(nodeName, name)
def _getShaderRefs(self):
"""(Deprecated) Return a vector of all shader references in this material element."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return getShaderNodes(self)
def _getActiveShaderRefs(self):
"""(Deprecated) Return a vector of all shader references in this material element, taking material inheritance into account."""
warnings.warn("This function is deprecated; shader references have been replaced with shader nodes in 1.38.", DeprecationWarning, stacklevel = 2)
return getShaderNodes(self)
Node.getReferencedNodeDef = _getReferencedNodeDef
Node.addShaderRef = _addShaderRef
Node.getShaderRefs = _getShaderRefs
Node.getActiveShaderRefs = _getActiveShaderRefs
#
# PropertySet
#
def _setPropertyValue(self, name, value, typeString = ''):
"""Set the typed value of a property by its name, creating a child element
to hold the property if needed."""
method = getattr(self.__class__, "_setPropertyValue" + getTypeString(value))
return method(self, name, value, typeString)
def _getPropertyValue(self, name, target = ''):
"""Return the typed value of a property by its name. If the given property
is not found, then None is returned."""
value = self._getPropertyValue(name)
return value.getData() if value else None
PropertySet.setPropertyValue = _setPropertyValue
PropertySet.getPropertyValue = _getPropertyValue
#
# GeomInfo
#
def _setGeomPropValue(self, name, value, typeString = ''):
"""Set the value of a geomprop by its name, creating a child element
to hold the geomprop if needed."""
method = getattr(self.__class__, "_setGeomPropValue" + getTypeString(value))
return method(self, name, value, typeString)
def _addGeomAttr(self, name):
"(Deprecated) Add a geomprop to this element."
warnings.warn("This function is deprecated; call GeomInfo.addGeomProp() instead", DeprecationWarning, stacklevel = 2)
return self.addGeomProp(name)
def _setGeomAttrValue(self, name, value, typeString = ''):
"(Deprecated) Set the value of a geomattr by its name."
warnings.warn("This function is deprecated; call GeomInfo.setGeomPropValue() instead", DeprecationWarning, stacklevel = 2)
return self.setGeomPropValue(name, value, typeString)
GeomInfo.setGeomPropValue = _setGeomPropValue
GeomInfo.addGeomAttr = _addGeomAttr
GeomInfo.setGeomAttrValue = _setGeomAttrValue
#
# Document
#
def _addMaterial(self, name):
"""(Deprecated) Add a material element to the document."""
warnings.warn("This function is deprecated; call Document.addMaterialNode() instead.", DeprecationWarning, stacklevel = 2)
return self.addMaterialNode(name)
def _getMaterials(self):
"""(Deprecated) Return a vector of all materials in the document."""
warnings.warn("This function is deprecated; call Document.getMaterialNodes() instead.", DeprecationWarning, stacklevel = 2)
return self.getMaterialNodes()
Document.addMaterial = _addMaterial
Document.getMaterials = _getMaterials
#
# Value
#
def _typeToName(t):
"(Deprecated) Return the MaterialX type string associated with the given Python type."
warnings.warn("This function is deprecated; call MaterialX.getTypeString instead.", DeprecationWarning, stacklevel = 2)
return getTypeString(t())
def _valueToString(value):
"(Deprecated) Convert a Python value to its corresponding MaterialX value string."
warnings.warn("This function is deprecated; call MaterialX.getValueString instead.", DeprecationWarning, stacklevel = 2)
return getValueString(value)
def _stringToValue(string, t):
"(Deprecated) Convert a MaterialX value string and Python type to the corresponding Python value."
warnings.warn("This function is deprecated; call MaterialX.createValueFromStrings instead.", DeprecationWarning, stacklevel = 2)
return createValueFromStrings(string, getTypeString(t()))
typeToName = _typeToName
valueToString = _valueToString
stringToValue = _stringToValue
#
# XmlIo
#
readFromXmlFile = readFromXmlFileBase
#
# Default Data Paths
#
def getDefaultDataSearchPath():
"""
Return the default data search path.
"""
return FileSearchPath(os.path.dirname(__file__))
def getDefaultDataLibraryFolders():
"""
Return list of default data library folders
"""
return [ 'libraries' ]

View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python
'''
Unit tests for shader generation in MaterialX Python.
'''
import os, unittest
import MaterialX as mx
import MaterialX.PyMaterialXGenShader as mx_gen_shader
import MaterialX.PyMaterialXGenOsl as mx_gen_osl
class TestGenShader(unittest.TestCase):
def test_ShaderInterface(self):
doc = mx.createDocument()
searchPath = mx.getDefaultDataSearchPath()
mx.loadLibraries(mx.getDefaultDataLibraryFolders(), searchPath, doc)
exampleName = u"shader_interface"
# Create a nodedef taking three color3 and producing another color3
nodeDef = doc.addNodeDef("ND_foo", "color3", "foo")
fooInputA = nodeDef.addInput("a", "color3")
fooInputB = nodeDef.addInput("b", "color3")
fooOutput = nodeDef.getOutput("out")
fooInputA.setValue(mx.Color3(1.0, 1.0, 0.0))
fooInputB.setValue(mx.Color3(0.8, 0.1, 0.1))
# Create an implementation graph for the nodedef performing
# a multiplication of the three colors.
nodeGraph = doc.addNodeGraph("IMP_foo")
nodeGraph.setAttribute("nodedef", nodeDef.getName())
output = nodeGraph.addOutput(fooOutput.getName(), "color3")
mult1 = nodeGraph.addNode("multiply", "mult1", "color3")
in1 = mult1.addInput("in1", "color3")
in1.setInterfaceName(fooInputA.getName())
in2 = mult1.addInput("in2", "color3")
in2.setInterfaceName(fooInputB.getName())
output.setConnectedNode(mult1)
doc.addNode("foo", "foo1", "color3")
output = doc.addOutput("foo_test", "color3");
output.setNodeName("foo1");
output.setAttribute("output", "o");
# Test for target
targetDefs = doc.getTargetDefs()
self.assertTrue(len(targetDefs))
shadergen = mx_gen_osl.OslShaderGenerator.create()
target = shadergen.getTarget()
foundTarget = next((
t for t in targetDefs
if t.getName() == target), None)
self.assertTrue(foundTarget)
context = mx_gen_shader.GenContext(shadergen)
context.registerSourceCodeSearchPath(searchPath)
shadergen.registerTypeDefs(doc);
# Test generator with complete mode
context.getOptions().shaderInterfaceType = mx_gen_shader.ShaderInterfaceType.SHADER_INTERFACE_COMPLETE;
shader = shadergen.generate(exampleName, output, context);
self.assertTrue(shader)
self.assertTrue(len(shader.getSourceCode(mx_gen_shader.PIXEL_STAGE)) > 0)
ps = shader.getStage(mx_gen_shader.PIXEL_STAGE);
uniforms = ps.getUniformBlock(mx_gen_osl.OSL_UNIFORMS)
self.assertTrue(uniforms.size() == 2)
outputs = ps.getOutputBlock(mx_gen_osl.OSL_OUTPUTS)
self.assertTrue(outputs.size() == 1)
self.assertTrue(outputs[0].getName() == output.getName())
file = open(shader.getName() + "_complete.osl", "w+")
file.write(shader.getSourceCode(mx_gen_shader.PIXEL_STAGE))
file.close()
os.remove(shader.getName() + "_complete.osl");
# Test generator with reduced mode
context.getOptions().shaderInterfaceType = mx_gen_shader.ShaderInterfaceType.SHADER_INTERFACE_REDUCED;
shader = shadergen.generate(exampleName, output, context);
self.assertTrue(shader)
self.assertTrue(len(shader.getSourceCode(mx_gen_shader.PIXEL_STAGE)) > 0)
ps = shader.getStage(mx_gen_shader.PIXEL_STAGE);
uniforms = ps.getUniformBlock(mx_gen_osl.OSL_UNIFORMS)
self.assertTrue(uniforms.size() == 0)
outputs = ps.getOutputBlock(mx_gen_osl.OSL_OUTPUTS)
self.assertTrue(outputs.size() == 1)
self.assertTrue(outputs[0].getName() == output.getName())
file = open(shader.getName() + "_reduced.osl", "w+")
file.write(shader.getSourceCode(mx_gen_shader.PIXEL_STAGE))
file.close()
os.remove(shader.getName() + "_reduced.osl");
# Define a custom attribute
customAttribute = doc.addAttributeDef("AD_attribute_node_name");
self.assertIsNotNone(customAttribute)
customAttribute.setType("string");
customAttribute.setAttrName("node_name");
customAttribute.setExportable(True);
# Define a nodedef referencing the custom attribute.
stdSurfNodeDef = doc.getNodeDef("ND_standard_surface_surfaceshader");
self.assertIsNotNone(stdSurfNodeDef)
stdSurfNodeDef.setAttribute("node_name", "Standard_Surface_Number_1");
self.assertTrue(stdSurfNodeDef.getAttribute("node_name") == "Standard_Surface_Number_1")
stdSurf1 = doc.addNodeInstance(stdSurfNodeDef, "standardSurface1");
self.assertIsNotNone(stdSurf1)
# Register shader metadata
shadergen.registerShaderMetadata(doc, context);
# Generate and test that attribute is in the code
context.getOptions().shaderInterfaceType = mx_gen_shader.ShaderInterfaceType.SHADER_INTERFACE_COMPLETE;
shader = shadergen.generate(stdSurf1.getName(), stdSurf1, context);
self.assertIsNotNone(shader)
code = shader.getSourceCode(mx_gen_shader.PIXEL_STAGE)
self.assertTrue('Standard_Surface_Number_1' in code)
self.assertTrue('node_name' in code)
print()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,517 @@
#!/usr/bin/env python
'''
Unit tests for MaterialX Python.
'''
import math, os, unittest
import MaterialX as mx
#--------------------------------------------------------------------------------
_testValues = (1,
True,
1.0,
mx.Color3(0.1, 0.2, 0.3),
mx.Color4(0.1, 0.2, 0.3, 0.4),
mx.Vector2(1.0, 2.0),
mx.Vector3(1.0, 2.0, 3.0),
mx.Vector4(1.0, 2.0, 3.0, 4.0),
mx.Matrix33(0.0),
mx.Matrix44(1.0),
'value',
[1, 2, 3],
[False, True, False],
[1.0, 2.0, 3.0],
['one', 'two', 'three'])
_fileDir = os.path.dirname(os.path.abspath(__file__))
_libraryDir = os.path.join(_fileDir, '../../libraries/stdlib/')
_exampleDir = os.path.join(_fileDir, '../../resources/Materials/Examples/')
_searchPath = _libraryDir + mx.PATH_LIST_SEPARATOR + _exampleDir
_libraryFilenames = ('stdlib_defs.mtlx',
'stdlib_ng.mtlx')
_exampleFilenames = ('StandardSurface/standard_surface_brass_tiled.mtlx',
'StandardSurface/standard_surface_brick_procedural.mtlx',
'StandardSurface/standard_surface_carpaint.mtlx',
'StandardSurface/standard_surface_marble_solid.mtlx',
'StandardSurface/standard_surface_look_brass_tiled.mtlx',
'UsdPreviewSurface/usd_preview_surface_gold.mtlx',
'UsdPreviewSurface/usd_preview_surface_plastic.mtlx')
_epsilon = 1e-4
#--------------------------------------------------------------------------------
class TestMaterialX(unittest.TestCase):
def test_Globals(self):
self.assertTrue(mx.__version__ == mx.getVersionString())
def test_DataTypes(self):
for value in _testValues:
valueString = mx.getValueString(value)
typeString = mx.getTypeString(value)
newValue = mx.createValueFromStrings(valueString, typeString)
self.assertTrue(newValue == value)
self.assertTrue(mx.getTypeString(newValue) == typeString)
def test_Vectors(self):
v1 = mx.Vector3(1, 2, 3)
v2 = mx.Vector3(2, 4, 6)
# Indexing operators
self.assertTrue(v1[2] == 3)
v1[2] = 4
self.assertTrue(v1[2] == 4)
v1[2] = 3
# Component-wise operators
self.assertTrue(v2 + v1 == mx.Vector3(3, 6, 9))
self.assertTrue(v2 - v1 == mx.Vector3(1, 2, 3))
self.assertTrue(v2 * v1 == mx.Vector3(2, 8, 18))
self.assertTrue(v2 / v1 == mx.Vector3(2, 2, 2))
v2 += v1
self.assertTrue(v2 == mx.Vector3(3, 6, 9))
v2 -= v1
self.assertTrue(v2 == mx.Vector3(2, 4, 6))
v2 *= v1
self.assertTrue(v2 == mx.Vector3(2, 8, 18))
v2 /= v1
self.assertTrue(v2 == mx.Vector3(2, 4, 6))
self.assertTrue(v1 * 2 == v2)
self.assertTrue(v2 / 2 == v1)
# Unary operation
self.assertTrue(-v1 == mx.Vector3(-1, -2, -3))
v1 *= -1
self.assertTrue(+v1 == mx.Vector3(-1, -2, -3))
v1 *= -1
# Geometric methods
v3 = mx.Vector4(4)
self.assertTrue(v3.getMagnitude() == 8)
self.assertTrue(v3.getNormalized().getMagnitude() == 1)
self.assertTrue(v1.dot(v2) == 28)
self.assertTrue(v1.cross(v2) == mx.Vector3())
# Vector copy
v4 = v2.copy()
self.assertTrue(v4 == v2)
v4[0] += 1;
self.assertTrue(v4 != v2)
def test_Matrices(self):
# Translation and scale
trans = mx.Matrix44.createTranslation(mx.Vector3(1, 2, 3))
scale = mx.Matrix44.createScale(mx.Vector3(2))
self.assertTrue(trans == mx.Matrix44(1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
1, 2, 3, 1))
self.assertTrue(scale == mx.Matrix44(2, 0, 0, 0,
0, 2, 0, 0,
0, 0, 2, 0,
0, 0, 0, 1))
# Indexing operators
self.assertTrue(trans[3, 2] == 3)
trans[3, 2] = 4
self.assertTrue(trans[3, 2] == 4)
trans[3, 2] = 3
# Matrix methods
self.assertTrue(trans.getTranspose() == mx.Matrix44(1, 0, 0, 1,
0, 1, 0, 2,
0, 0, 1, 3,
0, 0, 0, 1))
self.assertTrue(scale.getTranspose() == scale)
self.assertTrue(trans.getDeterminant() == 1)
self.assertTrue(scale.getDeterminant() == 8)
self.assertTrue(trans.getInverse() ==
mx.Matrix44.createTranslation(mx.Vector3(-1, -2, -3)))
# Matrix copy
trans2 = trans.copy()
self.assertTrue(trans2 == trans)
trans2[0, 0] += 1;
self.assertTrue(trans2 != trans)
# Matrix product
prod1 = trans * scale
prod2 = scale * trans
prod3 = trans * 2
prod4 = trans.copy()
prod4 *= scale
self.assertTrue(prod1 == mx.Matrix44(2, 0, 0, 0,
0, 2, 0, 0,
0, 0, 2, 0,
2, 4, 6, 1))
self.assertTrue(prod2 == mx.Matrix44(2, 0, 0, 0,
0, 2, 0, 0,
0, 0, 2, 0,
1, 2, 3, 1))
self.assertTrue(prod3 == mx.Matrix44(2, 0, 0, 0,
0, 2, 0, 0,
0, 0, 2, 0,
2, 4, 6, 2))
self.assertTrue(prod4 == prod1)
# Matrix division
quot1 = prod1 / scale
quot2 = prod2 / trans
quot3 = prod3 / 2
quot4 = quot1.copy()
quot4 /= trans
self.assertTrue(quot1 == trans)
self.assertTrue(quot2 == scale)
self.assertTrue(quot3 == trans)
self.assertTrue(quot4 == mx.Matrix44.IDENTITY)
# Unary operation
self.assertTrue(-trans == mx.Matrix44(-1, 0, 0, 0,
0, -1, 0, 0,
0, 0, -1, 0,
-1, -2, -3, -1))
trans *= -1
self.assertTrue(+trans == mx.Matrix44(-1, 0, 0, 0,
0, -1, 0, 0,
0, 0, -1, 0,
-1, -2, -3, -1))
trans *= -1
# 2D rotation
rot1 = mx.Matrix33.createRotation(math.pi / 2)
rot2 = mx.Matrix33.createRotation(math.pi)
self.assertTrue((rot1 * rot1).isEquivalent(rot2, _epsilon))
self.assertTrue(rot2.isEquivalent(
mx.Matrix33.createScale(mx.Vector2(-1)), _epsilon))
self.assertTrue((rot2 * rot2).isEquivalent(mx.Matrix33.IDENTITY, _epsilon))
# 3D rotation
rotX = mx.Matrix44.createRotationX(math.pi)
rotY = mx.Matrix44.createRotationY(math.pi)
rotZ = mx.Matrix44.createRotationZ(math.pi)
self.assertTrue((rotX * rotY).isEquivalent(
mx.Matrix44.createScale(mx.Vector3(-1, -1, 1)), _epsilon))
self.assertTrue((rotX * rotZ).isEquivalent(
mx.Matrix44.createScale(mx.Vector3(-1, 1, -1)), _epsilon))
self.assertTrue((rotY * rotZ).isEquivalent(
mx.Matrix44.createScale(mx.Vector3(1, -1, -1)), _epsilon))
def test_BuildDocument(self):
# Create a document.
doc = mx.createDocument()
# Create a node graph with constant and image sources.
nodeGraph = doc.addNodeGraph()
self.assertTrue(nodeGraph)
self.assertRaises(LookupError, doc.addNodeGraph, nodeGraph.getName())
constant = nodeGraph.addNode('constant')
image = nodeGraph.addNode('image')
# Connect sources to outputs.
output1 = nodeGraph.addOutput()
output2 = nodeGraph.addOutput()
output1.setConnectedNode(constant)
output2.setConnectedNode(image)
self.assertTrue(output1.getConnectedNode() == constant)
self.assertTrue(output2.getConnectedNode() == image)
self.assertTrue(output1.getUpstreamElement() == constant)
self.assertTrue(output2.getUpstreamElement() == image)
# Set constant node color.
color = mx.Color3(0.1, 0.2, 0.3)
constant.setInputValue('value', color)
self.assertTrue(constant.getInputValue('value') == color)
# Set image node file.
file = 'image1.tif'
image.setInputValue('file', file, 'filename')
self.assertTrue(image.getInputValue('file') == file)
# Create a custom nodedef.
nodeDef = doc.addNodeDef('nodeDef1', 'float', 'turbulence3d')
nodeDef.setInputValue('octaves', 3)
nodeDef.setInputValue('lacunarity', 2.0)
nodeDef.setInputValue('gain', 0.5)
# Reference the custom nodedef.
custom = nodeGraph.addNode('turbulence3d', 'turbulence1', 'float')
self.assertTrue(custom.getInputValue('octaves') == 3)
custom.setInputValue('octaves', 5)
self.assertTrue(custom.getInputValue('octaves') == 5)
# Test scoped attributes.
nodeGraph.setFilePrefix('folder/')
nodeGraph.setColorSpace('lin_rec709')
self.assertTrue(image.getInput('file').getResolvedValueString() == 'folder/image1.tif')
self.assertTrue(constant.getActiveColorSpace() == 'lin_rec709')
# Create a simple shader interface.
simpleSrf = doc.addNodeDef('', 'surfaceshader', 'simpleSrf')
simpleSrf.setInputValue('diffColor', mx.Color3(1.0))
simpleSrf.setInputValue('specColor', mx.Color3(0.0))
roughness = simpleSrf.setInputValue('roughness', 0.25)
self.assertTrue(roughness.getIsUniform() == False)
roughness.setIsUniform(True);
self.assertTrue(roughness.getIsUniform() == True)
# Instantiate shader and material nodes.
shaderNode = doc.addNodeInstance(simpleSrf)
materialNode = doc.addMaterialNode('', shaderNode)
# Bind the diffuse color input to the constant color output.
shaderNode.setConnectedOutput('diffColor', output1)
self.assertTrue(shaderNode.getUpstreamElement() == constant)
# Bind the roughness input to a value.
instanceRoughness = shaderNode.setInputValue('roughness', 0.5)
self.assertTrue(instanceRoughness.getValue() == 0.5)
self.assertTrue(instanceRoughness.getDefaultValue() == 0.25)
# Create a look for the material.
look = doc.addLook()
self.assertTrue(len(doc.getLooks()) == 1)
# Bind the material to a geometry string.
matAssign1 = look.addMaterialAssign("matAssign1", materialNode.getName())
matAssign1.setGeom("/robot1")
self.assertTrue(matAssign1.getReferencedMaterial() == materialNode)
self.assertTrue(len(mx.getGeometryBindings(materialNode, "/robot1")) == 1)
self.assertTrue(len(mx.getGeometryBindings(materialNode, "/robot2")) == 0)
# Bind the material to a collection.
matAssign2 = look.addMaterialAssign("matAssign2", materialNode.getName())
collection = doc.addCollection()
collection.setIncludeGeom("/robot2")
collection.setExcludeGeom("/robot2/left_arm")
matAssign2.setCollection(collection)
self.assertTrue(len(mx.getGeometryBindings(materialNode, "/robot2")) == 1)
self.assertTrue(len(mx.getGeometryBindings(materialNode, "/robot2/right_arm")) == 1)
self.assertTrue(len(mx.getGeometryBindings(materialNode, "/robot2/left_arm")) == 0)
# Create a property assignment.
propertyAssign = look.addPropertyAssign()
propertyAssign.setProperty("twosided")
propertyAssign.setGeom("/robot1")
propertyAssign.setValue(True)
self.assertTrue(propertyAssign.getProperty() == "twosided")
self.assertTrue(propertyAssign.getGeom() == "/robot1")
self.assertTrue(propertyAssign.getValue() == True)
# Create a property set assignment.
propertySet = doc.addPropertySet()
propertySet.setPropertyValue('matte', False)
self.assertTrue(propertySet.getPropertyValue('matte') == False)
propertySetAssign = look.addPropertySetAssign()
propertySetAssign.setPropertySet(propertySet)
propertySetAssign.setGeom('/robot1')
self.assertTrue(propertySetAssign.getPropertySet() == propertySet)
self.assertTrue(propertySetAssign.getGeom() == '/robot1')
# Create a variant set.
variantSet = doc.addVariantSet()
variantSet.addVariant("original")
variantSet.addVariant("damaged")
self.assertTrue(len(variantSet.getVariants()) == 2)
# Validate the document.
valid, message = doc.validate()
self.assertTrue(valid, 'Document returned validation warnings: ' + message)
# Disconnect outputs from sources.
output1.setConnectedNode(None)
output2.setConnectedNode(None)
self.assertTrue(output1.getConnectedNode() == None)
self.assertTrue(output2.getConnectedNode() == None)
def test_TraverseGraph(self):
# Create a document.
doc = mx.createDocument()
# Create a node graph with the following structure:
#
# [image1] [constant] [image2]
# \ / |
# [multiply] [contrast] [noise3d]
# \____________ | ____________/
# [mix]
# |
# [output]
#
nodeGraph = doc.addNodeGraph()
image1 = nodeGraph.addNode('image')
image2 = nodeGraph.addNode('image')
constant = nodeGraph.addNode('constant')
multiply = nodeGraph.addNode('multiply')
contrast = nodeGraph.addNode('contrast')
noise3d = nodeGraph.addNode('noise3d')
mix = nodeGraph.addNode('mix')
output = nodeGraph.addOutput()
multiply.setConnectedNode('in1', image1)
multiply.setConnectedNode('in2', constant)
contrast.setConnectedNode('in', image2)
mix.setConnectedNode('fg', multiply)
mix.setConnectedNode('bg', contrast)
mix.setConnectedNode('mask', noise3d)
output.setConnectedNode(mix)
# Validate the document.
valid, message = doc.validate()
self.assertTrue(valid, 'Document returned validation warnings: ' + message)
# Traverse the document tree (implicit iterator).
nodeCount = 0
for elem in doc.traverseTree():
if elem.isA(mx.Node):
nodeCount += 1
self.assertTrue(nodeCount == 7)
# Traverse the document tree (explicit iterator).
nodeCount = 0
maxElementDepth = 0
treeIter = doc.traverseTree()
for elem in treeIter:
if elem.isA(mx.Node):
nodeCount += 1
maxElementDepth = max(maxElementDepth, treeIter.getElementDepth())
self.assertTrue(nodeCount == 7)
self.assertTrue(maxElementDepth == 3)
# Traverse the document tree (prune subtree).
nodeCount = 0
treeIter = doc.traverseTree()
for elem in treeIter:
if elem.isA(mx.Node):
nodeCount += 1
if elem.isA(mx.NodeGraph):
treeIter.setPruneSubtree(True)
self.assertTrue(nodeCount == 0)
# Traverse upstream from the graph output (implicit iterator).
nodeCount = 0
for edge in output.traverseGraph():
upstreamElem = edge.getUpstreamElement()
connectingElem = edge.getConnectingElement()
downstreamElem = edge.getDownstreamElement()
if upstreamElem.isA(mx.Node):
nodeCount += 1
if downstreamElem.isA(mx.Node):
self.assertTrue(connectingElem.isA(mx.Input))
self.assertTrue(nodeCount == 7)
# Traverse upstream from the graph output (explicit iterator).
nodeCount = 0
maxElementDepth = 0
maxNodeDepth = 0
graphIter = output.traverseGraph()
for edge in graphIter:
upstreamElem = edge.getUpstreamElement()
connectingElem = edge.getConnectingElement()
downstreamElem = edge.getDownstreamElement()
if upstreamElem.isA(mx.Node):
nodeCount += 1
maxElementDepth = max(maxElementDepth, graphIter.getElementDepth())
maxNodeDepth = max(maxNodeDepth, graphIter.getNodeDepth())
self.assertTrue(nodeCount == 7)
self.assertTrue(maxElementDepth == 3)
self.assertTrue(maxNodeDepth == 3)
# Traverse upstream from the graph output (prune subgraph).
nodeCount = 0
graphIter = output.traverseGraph()
for edge in graphIter:
upstreamElem = edge.getUpstreamElement()
connectingElem = edge.getConnectingElement()
downstreamElem = edge.getDownstreamElement()
if upstreamElem.isA(mx.Node):
nodeCount += 1
if upstreamElem.getCategory() == 'multiply':
graphIter.setPruneSubgraph(True)
self.assertTrue(nodeCount == 5)
# Create and detect a cycle.
multiply.setConnectedNode('in2', mix)
self.assertTrue(output.hasUpstreamCycle())
self.assertFalse(doc.validate()[0])
multiply.setConnectedNode('in2', constant)
self.assertFalse(output.hasUpstreamCycle())
self.assertTrue(doc.validate()[0])
# Create and detect a loop.
contrast.setConnectedNode('in', contrast)
self.assertTrue(output.hasUpstreamCycle())
self.assertFalse(doc.validate()[0])
contrast.setConnectedNode('in', image2)
self.assertFalse(output.hasUpstreamCycle())
self.assertTrue(doc.validate()[0])
def test_Xmlio(self):
# Read the standard library.
libs = []
for filename in _libraryFilenames:
lib = mx.createDocument()
mx.readFromXmlFile(lib, filename, _searchPath)
libs.append(lib)
# Declare write predicate for write filter test
def skipLibraryElement(elem):
return not elem.hasSourceUri()
# Read and validate each example document.
for filename in _exampleFilenames:
doc = mx.createDocument()
mx.readFromXmlFile(doc, filename, _searchPath)
valid, message = doc.validate()
self.assertTrue(valid, filename + ' returned validation warnings: ' + message)
# Copy the document.
copiedDoc = doc.copy()
self.assertTrue(copiedDoc == doc)
copiedDoc.addLook()
self.assertTrue(copiedDoc != doc)
# Traverse the document tree.
valueElementCount = 0
for elem in doc.traverseTree():
if elem.isA(mx.ValueElement):
valueElementCount += 1
self.assertTrue(valueElementCount > 0)
# Serialize to XML.
writeOptions = mx.XmlWriteOptions()
writeOptions.writeXIncludeEnable = False
xmlString = mx.writeToXmlString(doc, writeOptions)
# Verify that the serialized document is identical.
writtenDoc = mx.createDocument()
mx.readFromXmlString(writtenDoc, xmlString)
self.assertTrue(writtenDoc == doc)
# Combine document with the standard library.
doc2 = doc.copy()
for lib in libs:
doc2.importLibrary(lib)
self.assertTrue(doc2.validate()[0])
# Write without definitions
writeOptions.writeXIncludeEnable = False
writeOptions.elementPredicate = skipLibraryElement
result = mx.writeToXmlString(doc2, writeOptions)
doc3 = mx.createDocument()
mx.readFromXmlString(doc3, result)
self.assertTrue(len(doc3.getNodeDefs()) == 0)
# Read the same document twice, and verify that duplicate elements
# are skipped.
doc = mx.createDocument()
filename = 'StandardSurface/standard_surface_carpaint.mtlx'
mx.readFromXmlFile(doc, filename, _searchPath)
mx.readFromXmlFile(doc, filename, _searchPath)
self.assertTrue(doc.validate()[0])
#--------------------------------------------------------------------------------
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,2 @@
@echo off
python tests_to_html.py -i1 ../../build %* -d

View File

@@ -0,0 +1,230 @@
#!/usr/bin/python
import sys
import os
import datetime
import argparse
try:
# Install pillow via pip to enable image differencing and statistics.
from PIL import Image, ImageChops, ImageStat
DIFF_ENABLED = True
except Exception:
DIFF_ENABLED = False
def computeDiff(image1Path, image2Path, imageDiffPath):
try:
if os.path.exists(imageDiffPath):
os.remove(imageDiffPath)
if not os.path.exists(image1Path):
print ("Image diff input missing: " + image1Path)
return
if not os.path.exists(image2Path):
print ("Image diff input missing: " + image2Path)
return
image1 = Image.open(image1Path).convert('RGB')
image2 = Image.open(image2Path).convert('RGB')
diff = ImageChops.difference(image1, image2)
diff.save(imageDiffPath)
diffStat = ImageStat.Stat(diff)
return sum(diffStat.rms) / (3.0 * 255.0)
except Exception:
if os.path.exists(imageDiffPath):
os.remove(imageDiffPath)
print ("Failed to create image diff between: " + image1Path + ", " + image2Path)
def main(args=None):
parser = argparse.ArgumentParser()
parser.add_argument('-i1', '--inputdir1', dest='inputdir1', action='store', help='Input directory', default=".")
parser.add_argument('-i2', '--inputdir2', dest='inputdir2', action='store', help='Second input directory', default="")
parser.add_argument('-i3', '--inputdir3', dest='inputdir3', action='store', help='Third input directory', default="")
parser.add_argument('-o', '--outputfile', dest='outputfile', action='store', help='Output file name', default="tests.html")
parser.add_argument('-d', '--diff', dest='CREATE_DIFF', action='store_true', help='Perform image diff', default=False)
parser.add_argument('-t', '--timestamp', dest='ENABLE_TIMESTAMPS', action='store_true', help='Write image timestamps', default=False)
parser.add_argument('-w', '--imagewidth', type=int, dest='imagewidth', action='store', help='Set image display width', default=256)
parser.add_argument('-ht', '--imageheight', type=int, dest='imageheight', action='store', help='Set image display height', default=256)
parser.add_argument('-cp', '--cellpadding', type=int, dest='cellpadding', action='store', help='Set table cell padding', default=0)
parser.add_argument('-tb', '--tableborder', type=int, dest='tableborder', action='store', help='Table border width. 0 means no border', default=3)
parser.add_argument('-l1', '--lang1', dest='lang1', action='store', help='First target language for comparison. Default is glsl', default="glsl")
parser.add_argument('-l2', '--lang2', dest='lang2', action='store', help='Second target language for comparison. Default is osl', default="osl")
parser.add_argument('-l3', '--lang3', dest='lang3', action='store', help='Third target language for comparison. Default is empty', default="")
parser.add_argument('-e', '--error', dest='error', action='store', help='Filter out results with RMS less than this. Negative means all results are kept.', default=-1, type=float)
args = parser.parse_args(args)
fh = open(args.outputfile,"w+")
fh.write("<html>\n")
fh.write("<style>\n")
fh.write("td {")
fh.write(" padding: " + str(args.cellpadding) + ";")
fh.write(" border: " + str(args.tableborder) + "px solid black;")
fh.write("}")
fh.write("table, tbody, th, .td_image {")
fh.write(" border-collapse: collapse;")
fh.write(" padding: 0;")
fh.write(" margin: 0;")
fh.write("}")
fh.write("</style>")
fh.write("<body>\n")
if args.inputdir1 == ".":
args.inputdir1 = os.getcwd()
if args.inputdir2 == ".":
args.inputdir2 = os.getcwd()
elif args.inputdir2 == "":
args.inputdir2 = args.inputdir1
if args.inputdir3 == ".":
args.inputdir3 = os.getcwd()
elif args.inputdir3 == "":
args.inputdir3 = args.inputdir1
useThirdLang = args.lang3
if useThirdLang:
fh.write("<h3>" + args.lang1 + " (in: " + args.inputdir1 + ") vs "+ args.lang2 + " (in: " + args.inputdir2 + ") vs "+ args.lang3 + " (in: " + args.inputdir3 + ")</h3>\n")
else:
fh.write("<h3>" + args.lang1 + " (in: " + args.inputdir1 + ") vs "+ args.lang2 + " (in: " + args.inputdir2 + ")</h3>\n")
if not DIFF_ENABLED and args.CREATE_DIFF:
print("--diff argument ignored. Diff utility not installed.")
# Remove potential trailing path separators
if args.inputdir1[-1:] == '/' or args.inputdir1[-1:] == '\\':
args.inputdir1 = args.inputdir1[:-1]
if args.inputdir2[-1:] == '/' or args.inputdir2[-1:] == '\\':
args.inputdir2 = args.inputdir2[:-1]
if args.inputdir3[-1:] == '/' or args.inputdir3[-1:] == '\\':
args.inputdir3 = args.inputdir3[:-1]
# Get all source files
langFiles1 = []
langPaths1 = []
for subdir, _, files in os.walk(args.inputdir1):
for curFile in files:
if curFile.endswith(args.lang1 + ".png"):
langFiles1.append(curFile)
langPaths1.append(subdir)
# Get all destination files, matching source files
langFiles2 = []
langPaths2 = []
langFiles3 = []
langPaths3 = []
preFixLen: int = len(args.inputdir1) + 1 # including the path separator
postFix: str = args.lang1 + ".png"
for file1, path1 in zip(langFiles1, langPaths1):
# Allow for just one language to be shown if source and dest are the same.
# Otherwise add in equivalent name with dest language replacement if
# pointing to the same directory
if args.inputdir1 != args.inputdir2 or args.lang1 != args.lang2:
file2 = file1[:-len(postFix)] + args.lang2 + ".png"
path2 = os.path.join(args.inputdir2, path1[len(args.inputdir1)+1:])
else:
file2 = ""
path2 = None
langFiles2.append(file2)
langPaths2.append(path2)
if useThirdLang:
file3 = file1[:-len(postFix)] + args.lang3 + ".png"
path3 = os.path.join(args.inputdir2, path1[len(args.inputdir1)+1:])
else:
file3 = ""
path3 = None
langFiles3.append(file3)
langPaths3.append(path3)
if langFiles1:
curPath = ""
for file1, file2, file3, path1, path2, path3 in zip(langFiles1, langFiles2, langFiles3, langPaths1, langPaths2, langPaths3):
fullPath1 = os.path.join(path1, file1) if file1 else None
fullPath2 = os.path.join(path2, file2) if file2 else None
fullPath3 = os.path.join(path3, file3) if file3 else None
diffPath1 = diffPath2 = diffPath3 = None
diffRms1 = diffRms2 = diffRms3 = None
if file1 and file2 and DIFF_ENABLED and args.CREATE_DIFF:
diffPath1 = fullPath1[0:-8] + "_" + args.lang1 + "-1_vs_" + args.lang2 + "-2_diff.png"
diffRms1 = computeDiff(fullPath1, fullPath2, diffPath1)
if useThirdLang and file1 and file3 and DIFF_ENABLED and args.CREATE_DIFF:
diffPath2 = fullPath1[0:-8] + "_" + args.lang1 + "-1_vs_" + args.lang3 + "-3_diff.png"
diffRms2 = computeDiff(fullPath1, fullPath3, diffPath2)
diffPath3 = fullPath1[0:-8] + "_" + args.lang2 + "-2_vs_" + args.lang3 + "-3_diff.png"
diffRms3 = computeDiff(fullPath2, fullPath3, diffPath3)
if args.error >= 0:
ok1 = (not diffPath1) or (not diffRms1) or (diffRms1 and diffRms1 <= args.error)
ok2 = (not diffPath2) or (not diffRms2) or (diffRms2 and diffRms2 <= args.error)
ok3 = (not diffPath3) or (not diffRms3) or (diffRms3 and diffRms3 <= args.error)
if ok1 and ok2 and ok3:
continue
if curPath != path1:
if curPath != "":
fh.write("</table>\n")
fh.write("<p>" + os.path.normpath(path1) + ":</p>\n")
fh.write("<table>\n")
curPath = path1
def prependFileUri(filepath: str) -> str:
if os.path.isabs(filepath):
return 'file:///' + filepath
else:
return filepath
fh.write("<tr>\n")
if fullPath1:
fh.write("<td class='td_image'><img src='" + prependFileUri(fullPath1) + "' height='" + str(args.imageheight) + "' width='" + str(args.imagewidth) + "' loading='lazy' style='background-color:black;'/></td>\n")
if fullPath2:
fh.write("<td class='td_image'><img src='" + prependFileUri(fullPath2) + "' height='" + str(args.imageheight) + "' width='" + str(args.imagewidth) + "' loading='lazy' style='background-color:black;'/></td>\n")
if fullPath3:
fh.write("<td class='td_image'><img src='" + prependFileUri(fullPath3) + "' height='" + str(args.imageheight) + "' width='" + str(args.imagewidth) + "' loading='lazy' style='background-color:black;'/></td>\n")
if diffPath1:
fh.write("<td class='td_image'><img src='" + prependFileUri(diffPath1) + "' height='" + str(args.imageheight) + "' width='" + str(args.imagewidth) + "' loading='lazy' style='background-color:black;'/></td>\n")
if diffPath2:
fh.write("<td class='td_image'><img src='" + prependFileUri(diffPath2) + "' height='" + str(args.imageheight) + "' width='" + str(args.imagewidth) + "' loading='lazy' style='background-color:black;'/></td>\n")
if diffPath3:
fh.write("<td class='td_image'><img src='" + prependFileUri(diffPath3) + "' height='" + str(args.imageheight) + "' width='" + str(args.imagewidth) + "' loading='lazy' style='background-color:black;'/></td>\n")
fh.write("</tr>\n")
fh.write("<tr>\n")
if fullPath1:
fh.write("<td align='center'>" + file1)
if args.ENABLE_TIMESTAMPS and os.path.isfile(fullPath1):
fh.write("<br>(" + str(datetime.datetime.fromtimestamp(os.path.getmtime(fullPath1))) + ")")
fh.write("</td>\n")
if fullPath2:
fh.write("<td align='center'>" + file2)
if args.ENABLE_TIMESTAMPS and os.path.isfile(fullPath2):
fh.write("<br>(" + str(datetime.datetime.fromtimestamp(os.path.getmtime(fullPath2))) + ")")
fh.write("</td>\n")
if fullPath3:
fh.write("<td align='center'>" + file3)
if args.ENABLE_TIMESTAMPS and os.path.isfile(fullPath3):
fh.write("<br>(" + str(datetime.datetime.fromtimestamp(os.path.getmtime(fullPath3))) + ")")
fh.write("</td>\n")
if diffPath1:
rms = " (RMS " + "%.5f" % diffRms1 + ")" if diffRms1 else ""
fh.write("<td align='center'>" + args.lang1.upper() + " vs. " + args.lang2.upper() + rms + "</td>\n")
if diffPath2:
rms = " (RMS " + "%.5f" % diffRms2 + ")" if diffRms2 else ""
fh.write("<td align='center'>" + args.lang1.upper() + " vs. " + args.lang3.upper() + rms + "</td>\n")
if diffPath3:
rms = " (RMS " + "%.5f" % diffRms3 + ")" if diffRms3 else ""
fh.write("<td align='center'>" + args.lang2.upper() + " vs. " + args.lang3.upper() + rms + "</td>\n")
fh.write("</tr>\n")
fh.write("</table>\n")
fh.write("</body>\n")
fh.write("</html>\n")
if __name__ == "__main__":
main(sys.argv[1:])

View File

@@ -0,0 +1,3 @@
# Python Code Examples
This folder contains example Python scripts that generate, process, and validate material content using the MaterialX API.

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python
'''
Generate a baked version of each material in the input document, using the TextureBaker class in the MaterialXRenderGlsl library.
'''
import sys, os, argparse
from sys import platform
import MaterialX as mx
from MaterialX import PyMaterialXRender as mx_render
from MaterialX import PyMaterialXRenderGlsl as mx_render_glsl
if platform == "darwin":
from MaterialX import PyMaterialXRenderMsl as mx_render_msl
def main():
parser = argparse.ArgumentParser(description="Generate a baked version of each material in the input document.")
parser.add_argument("--width", dest="width", type=int, default=1024, help="Specify the width of baked textures.")
parser.add_argument("--height", dest="height", type=int, default=1024, help="Specify the height of baked textures.")
parser.add_argument("--hdr", dest="hdr", action="store_true", help="Save images to hdr format.")
parser.add_argument("--average", dest="average", action="store_true", help="Average baked images to generate constant values.")
parser.add_argument("--path", dest="paths", action='append', nargs='+', help="An additional absolute search path location (e.g. '/projects/MaterialX')")
parser.add_argument("--library", dest="libraries", action='append', nargs='+', help="An additional relative path to a custom data library folder (e.g. 'libraries/custom')")
parser.add_argument('--writeDocumentPerMaterial', dest='writeDocumentPerMaterial', type=mx.stringToBoolean, default=True, help='Specify whether to write baked materials to separate MaterialX documents. Default is True')
if platform == "darwin":
parser.add_argument("--glsl", dest="useGlslBackend", default=False, type=bool, help="Set to True to use GLSL backend (default = Metal).")
parser.add_argument(dest="inputFilename", help="Filename of the input document.")
parser.add_argument(dest="outputFilename", help="Filename of the output document.")
opts = parser.parse_args()
# Load standard and custom data libraries.
stdlib = mx.createDocument()
searchPath = mx.getDefaultDataSearchPath()
searchPath.append(os.path.dirname(opts.inputFilename))
libraryFolders = []
if opts.paths:
for pathList in opts.paths:
for path in pathList:
searchPath.append(path)
if opts.libraries:
for libraryList in opts.libraries:
for library in libraryList:
libraryFolders.append(library)
libraryFolders.extend(mx.getDefaultDataLibraryFolders())
mx.loadLibraries(libraryFolders, searchPath, stdlib)
# Read and validate the source document.
doc = mx.createDocument()
try:
mx.readFromXmlFile(doc, opts.inputFilename)
doc.setDataLibrary(stdlib)
except mx.ExceptionFileMissing as err:
print(err)
sys.exit(0)
valid, msg = doc.validate()
if not valid:
print("Validation warnings for input document:")
print(msg)
# Construct the texture baker.
baseType = mx_render.BaseType.FLOAT if opts.hdr else mx_render.BaseType.UINT8
if platform == "darwin" and not opts.useGlslBackend:
baker = mx_render_msl.TextureBaker.create(opts.width, opts.height, baseType)
else:
baker = mx_render_glsl.TextureBaker.create(opts.width, opts.height, baseType)
# Bake materials to textures.
if opts.average:
baker.setAverageImages(True)
baker.writeDocumentPerMaterial(opts.writeDocumentPerMaterial)
baker.bakeAllMaterials(doc, searchPath, opts.outputFilename)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,923 @@
#!/usr/bin/env python
'''
Compare node definitions between a specification Markdown document and a
data library MaterialX document.
Report any differences between the two in their supported node sets, typed
node signatures, and default values.
'''
import argparse
import re
from dataclasses import dataclass, field
from enum import Enum
from itertools import product
from pathlib import Path
import MaterialX as mx
# -----------------------------------------------------------------------------
# Type System
# -----------------------------------------------------------------------------
def loadStandardLibraries():
'''Load and return the standard MaterialX libraries as a document.'''
stdlib = mx.createDocument()
mx.loadLibraries(mx.getDefaultDataLibraryFolders(), mx.getDefaultDataSearchPath(), stdlib)
return stdlib
def getStandardTypes(stdlib):
'''Extract the set of standard type names from library TypeDefs.'''
return {td.getName() for td in stdlib.getTypeDefs()}
def buildTypeGroups(stdlib):
'''
Build type groups from standard library TypeDefs.
Derives colorN, vectorN, matrixNN groups from type naming patterns.
'''
groups = {}
for td in stdlib.getTypeDefs():
name = td.getName()
# Match colorN, vectorN patterns (color3, vector2, etc.)
match = re.match(r'^(color|vector)(\d)$', name)
if match:
groupName = f'{match.group(1)}N'
groups.setdefault(groupName, set()).add(name)
continue
# Match matrixNN pattern (matrix33, matrix44)
match = re.match(r'^matrix(\d)\1$', name)
if match:
groups.setdefault('matrixNN', set()).add(name)
return groups
def buildTypeGroupVariables(typeGroups):
'''Build type group variables (e.g., colorM from colorN) for "must differ" constraints.'''
variables = {}
for groupName in typeGroups:
if groupName.endswith('N') and not groupName.endswith('NN'):
variantName = groupName[:-1] + 'M'
variables[variantName] = groupName
return variables
def parseSpecTypes(typeStr):
'''
Parse a specification type string into (types, typeRef).
Supported patterns:
- Simple types: "float", "color3"
- Comma-separated: "float, color3"
- Union with "or": "BSDF or VDF", "BSDF, EDF, or VDF"
- Type references: "Same as bg", "Same as in1 or float"
'''
if typeStr is None or not typeStr.strip():
return set(), None
typeStr = typeStr.strip()
# Handle "Same as X" and "Same as X or Y" references
sameAsMatch = re.match(r'^Same as\s+`?(\w+)`?(?:\s+or\s+(.+))?$', typeStr, re.IGNORECASE)
if sameAsMatch:
refPort = sameAsMatch.group(1)
extraTypes = sameAsMatch.group(2)
extraSet = set()
if extraTypes:
extraSet, _ = parseSpecTypes(extraTypes)
return extraSet, refPort
# Normalize "or" to comma: "X or Y" -> "X, Y", "X, Y, or Z" -> "X, Y, Z"
normalized = re.sub(r',?\s+or\s+', ', ', typeStr)
result = set()
for t in normalized.split(','):
t = t.strip()
if t:
result.add(t)
return result, None
def expandTypeSet(types, typeGroups, typeGroupVariables):
'''Expand type groups to concrete types. Returns list of (concreteType, groupName) tuples.'''
result = []
for t in types:
if t in typeGroups:
for concrete in typeGroups[t]:
result.append((concrete, t))
elif t in typeGroupVariables:
baseGroup = typeGroupVariables[t]
for concrete in typeGroups[baseGroup]:
result.append((concrete, t))
else:
result.append((t, None))
return result
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
class MatchType(Enum):
'''Types of signature matches between spec and library.'''
EXACT = 'exact' # Identical inputs and outputs
DIFFERENT_INPUTS = 'different_inputs' # Same outputs but different inputs
class DiffType(Enum):
'''Categories of differences between spec and library, with display labels.'''
# Invalid specification entries
SPEC_COLUMN_MISMATCH = 'Column Count Mismatches in Specification'
SPEC_EMPTY_PORT_NAME = 'Empty Port Names in Specification'
SPEC_UNRECOGNIZED_TYPE = 'Unrecognized Types in Specification'
# Node-level differences
NODE_MISSING_IN_LIBRARY = 'Nodes in Specification but not Data Library'
NODE_MISSING_IN_SPEC = 'Nodes in Data Library but not Specification'
# Signature-level differences
SIGNATURE_DIFFERENT_INPUTS = 'Nodes with Different Input Sets'
SIGNATURE_MISSING_IN_LIBRARY = 'Node Signatures in Specification but not Data Library'
SIGNATURE_MISSING_IN_SPEC = 'Node Signatures in Data Library but not Specification'
# Default value differences
DEFAULT_MISMATCH = 'Default Value Mismatches'
@dataclass
class PortInfo:
'''Information about an input or output port from the specification.'''
name: str
types: set = field(default_factory=set)
typeRef: str = None # For "Same as X" references
default: str = None # Spec default string (before type-specific expansion)
@dataclass(frozen=True)
class NodeSignature:
'''A typed combination of inputs and outputs, corresponding to one nodedef.'''
inputs: tuple # ((name, type), ...) sorted for hashing
outputs: tuple # ((name, type), ...) sorted for hashing
_displayInputs: tuple = None
_displayOutputs: tuple = None
@classmethod
def create(cls, inputs, outputs):
'''Create a NodeSignature from input/output dicts of name -> type.'''
return cls(
inputs=tuple(sorted(inputs.items())),
outputs=tuple(sorted(outputs.items())),
_displayInputs=tuple(inputs.items()),
_displayOutputs=tuple(outputs.items()),
)
def __hash__(self):
return hash((self.inputs, self.outputs))
def __eq__(self, other):
if not isinstance(other, NodeSignature):
return False
return self.inputs == other.inputs and self.outputs == other.outputs
def __str__(self):
insStr = ', '.join(f'{n}:{t}' for n, t in self._displayInputs)
outsStr = ', '.join(f'{n}:{t}' for n, t in self._displayOutputs)
return f'({insStr}) -> {outsStr}'
@dataclass
class NodeInfo:
'''A node and its supported signatures.'''
name: str
signatures: set = field(default_factory=set)
_specInputs: dict = field(default_factory=dict) # For default value comparison
@dataclass
class Difference:
'''A difference found between spec and data library.'''
diffType: DiffType
node: str
port: str = None
signature: NodeSignature = None
extraInLib: tuple = None
extraInSpec: tuple = None
valueType: str = None
specDefault: str = None
libDefault: str = None
def formatDifference(diff):
'''Format a Difference for display, returning a list of lines.'''
# Default mismatch
if diff.diffType == DiffType.DEFAULT_MISMATCH:
return [
f' {diff.node}.{diff.port} ({diff.valueType}):',
f' Signature: {diff.signature}',
f' Spec default: {diff.specDefault}',
f' Data library default: {diff.libDefault}',
]
# Different input sets
if diff.diffType == DiffType.SIGNATURE_DIFFERENT_INPUTS:
lines = [f' {diff.node}: {diff.signature}']
if diff.extraInLib:
extraStr = ', '.join(f'{n}:{t}' for n, t in diff.extraInLib)
lines.append(f' Extra in library: {extraStr}')
if diff.extraInSpec:
extraStr = ', '.join(f'{n}:{t}' for n, t in diff.extraInSpec)
lines.append(f' Extra in spec: {extraStr}')
return lines
# Signature mismatch (missing in spec or library)
if diff.signature:
return [f' {diff.node}: {diff.signature}']
# Spec validation error with port
if diff.port:
return [f' {diff.node}.{diff.port}']
# Node-level difference or simple spec validation error
return [f' {diff.node}']
# -----------------------------------------------------------------------------
# Default Value Utilities
# -----------------------------------------------------------------------------
def buildGeompropNames(stdlib):
'''Extract geomprop names from standard library GeomPropDefs.'''
return {gpd.getName() for gpd in stdlib.getGeomPropDefs()}
def getComponentCount(typeName):
'''Get the number of components for a MaterialX type, or None if unknown.'''
if typeName in ('float', 'integer', 'boolean'):
return 1
# Match colorN, vectorN patterns
match = re.match(r'^(color|vector)(\d)$', typeName)
if match:
return int(match.group(2))
# Match matrixNN pattern
match = re.match(r'^matrix(\d)(\d)$', typeName)
if match:
return int(match.group(1)) * int(match.group(2))
return None
def expandDefaultPlaceholder(placeholder, typeName):
'''Expand a placeholder (0, 1, 0.5) to a type-appropriate value string.'''
count = getComponentCount(typeName)
if count is None:
return None
if placeholder == '0':
if typeName == 'boolean':
return 'false'
return ', '.join(['0'] * count)
if placeholder == '1':
if typeName == 'boolean':
return 'true'
# Identity matrices, not all-ones
if typeName == 'matrix33':
return '1, 0, 0, 0, 1, 0, 0, 0, 1'
if typeName == 'matrix44':
return '1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1'
return ', '.join(['1'] * count)
if placeholder == '0.5':
if typeName in ('integer', 'boolean'):
return None # 0.5 doesn't apply to these types
return ', '.join(['0.5'] * count)
return None
def parseSpecDefault(value, specDefaultNotation):
'''Parse specification default value notation into normalized form.'''
if value is None:
return None
value = value.strip()
return specDefaultNotation.get(value, value)
def expandSpecDefaultToValue(specDefault, valueType, geompropNames):
'''Parse a spec default to a typed MaterialX value. Returns (value, isGeomprop).'''
if specDefault is None or specDefault == '':
return None, False
# Handle geomprop references - these are compared as strings, not typed values
if specDefault in geompropNames:
return specDefault, True
# Expand placeholder values to type-appropriate strings
expansion = expandDefaultPlaceholder(specDefault, valueType)
if expansion is not None:
specDefault = expansion
# Parse to typed value using MaterialX
try:
return mx.createValueFromStrings(specDefault, valueType), False
except Exception:
return None, False
def formatDefaultValue(value, valueType, geompropNames):
'''Format a default value for display using spec notation (__zero__, etc.).'''
if value is None:
return 'None'
# Handle string values (geomprops, empty strings, etc.)
if isinstance(value, str):
if value in geompropNames:
return f'_{value}_'
return '__empty__' if value == '' else value
# Check if value matches a standard constant (__zero__, __one__, __half__)
for placeholder, display in [('0', '__zero__'), ('1', '__one__'), ('0.5', '__half__')]:
expansion = expandDefaultPlaceholder(placeholder, valueType)
if expansion is None:
continue
try:
if value == mx.createValueFromStrings(expansion, valueType):
return display
except Exception:
pass
# Fall back to string representation
return str(value)
# -----------------------------------------------------------------------------
# Markdown Table Parsing
# -----------------------------------------------------------------------------
def parseMarkdownTable(lines, startIdx):
'''Parse a markdown table. Returns (rows, columnMismatchCount, endIndex).'''
table = []
headers = []
columnMismatchCount = 0
idx = startIdx
# Parse header row
if idx < len(lines) and '|' in lines[idx]:
headerLine = lines[idx].strip()
headers = [h.strip().strip('`') for h in headerLine.split('|')[1:-1]]
idx += 1
else:
return [], 0, startIdx
# Skip separator row
if idx < len(lines) and '|' in lines[idx] and '-' in lines[idx]:
idx += 1
else:
return [], 0, startIdx
# Parse data rows
while idx < len(lines):
line = lines[idx].strip()
if not line or not line.startswith('|'):
break
cells = [c.strip().strip('`') for c in line.split('|')[1:-1]]
if len(cells) == len(headers):
row = {headers[i].lower(): cells[i] for i in range(len(headers))}
table.append(row)
else:
columnMismatchCount += 1
idx += 1
return table, columnMismatchCount, idx
# -----------------------------------------------------------------------------
# Specification Document Parsing
# -----------------------------------------------------------------------------
def isValidTypeGroupAssignment(driverNames, combo, typeGroupVariables):
'''
Check if type assignments satisfy group constraints (e.g., colorN ports must
match, colorM must differ from colorN). Returns (isValid, typeAssignment).
'''
typeAssignment = {}
groupAssignments = {} # groupName -> concreteType assigned to that group
for name, (concreteType, groupName) in zip(driverNames, combo):
typeAssignment[name] = concreteType
# Skip constraint checking for None types (these will be resolved via typeRef)
if concreteType is None:
continue
if not groupName:
continue
# For group variables (colorM), get the base group (colorN)
baseGroup = typeGroupVariables.get(groupName, groupName)
isVariable = groupName in typeGroupVariables
# Check consistency: all uses of the same group must have same concrete type
if groupName in groupAssignments:
if groupAssignments[groupName] != concreteType:
return False, None
else:
groupAssignments[groupName] = concreteType
# For variables: must differ from base group if base is already assigned
if isVariable and baseGroup in groupAssignments:
if groupAssignments[baseGroup] == concreteType:
return False, None
return True, typeAssignment
def expandSpecSignatures(inputs, outputs, typeGroups, typeGroupVariables):
'''
Expand spec port definitions into concrete NodeSignatures.
Handles type groups, type group variables, and "Same as X or Y" patterns.
'''
allPorts = {**inputs, **outputs}
# Identify driver ports and their type options
# - Ports with explicit types (no typeRef): use those types
# - Ports with both types AND typeRef ("Same as X or Y"): explicit types OR inherit from typeRef
drivers = {}
for name, port in allPorts.items():
if port.types and not port.typeRef:
# Normal driver: explicit types only
drivers[name] = expandTypeSet(port.types, typeGroups, typeGroupVariables)
elif port.types and port.typeRef:
# "Same as X or Y" pattern: explicit types OR inherit from typeRef
expanded = expandTypeSet(port.types, typeGroups, typeGroupVariables)
expanded.append((None, None)) # None means "inherit from typeRef"
drivers[name] = expanded
if not drivers:
return set()
# Generate all combinations of driver types
driverNames = sorted(drivers.keys())
driverTypeLists = [drivers[n] for n in driverNames]
signatures = set()
for combo in product(*driverTypeLists):
# Validate type group constraints (skip None values which will be resolved via typeRef)
valid, typeAssignment = isValidTypeGroupAssignment(driverNames, combo, typeGroupVariables)
if not valid:
continue
# Remove None assignments - these ports will be resolved via typeRef
typeAssignment = {k: v for k, v in typeAssignment.items() if v is not None}
# Resolve typeRefs for this combination
resolved = resolveTypeAssignment(typeAssignment, allPorts)
if resolved is None:
continue
# Build signature
sigInputs = {name: resolved[name] for name in inputs if name in resolved}
sigOutputs = {name: resolved[name] for name in outputs if name in resolved}
signatures.add(NodeSignature.create(sigInputs, sigOutputs))
return signatures
def resolveTypeAssignment(baseAssignment, allPorts):
'''Resolve "Same as X" references to complete port type assignments.'''
assignment = baseAssignment.copy()
# Iteratively resolve references (limit iterations to handle circular refs)
for _ in range(10):
changed = False
for name, port in allPorts.items():
if name in assignment:
continue
if port.typeRef and port.typeRef in assignment:
assignment[name] = assignment[port.typeRef]
changed = True
if not changed:
break
# Check all ports resolved
if set(assignment.keys()) != set(allPorts.keys()):
return None
return assignment
def resolvePortTypeRefs(ports):
'''Resolve type references between ports by copying types. Modifies ports in place.'''
# Limit iterations to handle circular refs
for _ in range(10):
changed = False
for port in ports.values():
if port.typeRef:
refPort = ports.get(port.typeRef)
if refPort and refPort.types:
port.types.update(refPort.types)
port.typeRef = None
changed = True
if not changed:
break
def parseSpecDocument(specPath, stdlib, geompropNames):
'''Parse a specification markdown document. Returns (nodes, invalidEntries).'''
# Build type system data from stdlib
standardTypes = getStandardTypes(stdlib)
typeGroups = buildTypeGroups(stdlib)
typeGroupVariables = buildTypeGroupVariables(typeGroups)
# Build derived values for validation and parsing
knownTypes = standardTypes | set(typeGroups.keys()) | set(typeGroupVariables.keys())
specDefaultNotation = {
'__zero__': '0',
'__one__': '1',
'__half__': '0.5',
'__empty__': '',
}
for name in geompropNames:
specDefaultNotation[f'_{name}_'] = name
nodes = {}
invalidEntries = []
with open(specPath, 'r', encoding='utf-8') as f:
content = f.read()
lines = content.split('\n')
currentNode = None
currentTableInputs = {}
currentTableOutputs = {}
idx = 0
def finalizeCurrentTable():
'''Expand current table to signatures and add to node.'''
nonlocal currentTableInputs, currentTableOutputs
if currentNode and (currentTableInputs or currentTableOutputs):
node = nodes[currentNode]
# Expand to signatures (do NOT pre-resolve typeRefs - expansion handles them)
tableSigs = expandSpecSignatures(currentTableInputs, currentTableOutputs, typeGroups, typeGroupVariables)
node.signatures.update(tableSigs)
# Merge input port info for default comparison (resolve types for defaults)
allPorts = {**currentTableInputs, **currentTableOutputs}
resolvePortTypeRefs(allPorts)
for name, port in currentTableInputs.items():
if name not in node._specInputs:
node._specInputs[name] = port
else:
node._specInputs[name].types.update(port.types)
currentTableInputs = {}
currentTableOutputs = {}
while idx < len(lines):
line = lines[idx]
# Look for node headers (### `nodename`)
nodeMatch = re.match(r'^###\s+`([^`]+)`', line)
if nodeMatch:
# Finalize previous table before switching nodes
finalizeCurrentTable()
currentNode = nodeMatch.group(1)
if currentNode not in nodes:
nodes[currentNode] = NodeInfo(name=currentNode)
idx += 1
continue
# Look for tables when we have a current node
if currentNode and '|' in line and 'Port' in line:
# Finalize previous table before starting new one
finalizeCurrentTable()
rows, columnMismatchCount, idx = parseMarkdownTable(lines, idx)
# Track column count mismatches
for _ in range(columnMismatchCount):
invalidEntries.append(Difference(
diffType=DiffType.SPEC_COLUMN_MISMATCH,
node=currentNode,
))
if rows:
for row in rows:
portName = row.get('port', '').strip('`*')
# Track empty port names
if not portName:
invalidEntries.append(Difference(
diffType=DiffType.SPEC_EMPTY_PORT_NAME,
node=currentNode,
))
continue
portType = row.get('type', '')
portDefault = row.get('default', '')
portDesc = row.get('description', '')
types, typeRef = parseSpecTypes(portType)
# Track unrecognized types
if types - knownTypes:
invalidEntries.append(Difference(
diffType=DiffType.SPEC_UNRECOGNIZED_TYPE,
node=currentNode,
port=portName,
))
# Determine if this is an output port
isOutput = portName == 'out' or portDesc.lower().startswith('output')
target = currentTableOutputs if isOutput else currentTableInputs
# Create port info for this table
portInfo = target.setdefault(portName, PortInfo(
name=portName,
default=parseSpecDefault(portDefault, specDefaultNotation),
))
portInfo.types.update(types)
if typeRef and not portInfo.typeRef:
portInfo.typeRef = typeRef
continue
idx += 1
# Finalize the last table
finalizeCurrentTable()
return nodes, invalidEntries
# -----------------------------------------------------------------------------
# Data Library Loading
# -----------------------------------------------------------------------------
def loadDataLibrary(mtlxPath):
'''Load a data library MTLX document. Returns (nodes, defaults).'''
doc = mx.createDocument()
mx.readFromXmlFile(doc, str(mtlxPath))
nodes = {}
defaults = {} # (nodeName, signature) -> {portName -> (value, isGeomprop)}
for nodedef in doc.getNodeDefs():
nodeName = nodedef.getNodeString()
node = nodes.setdefault(nodeName, NodeInfo(name=nodeName))
# Build signature from this nodedef
sigInputs = {inp.getName(): inp.getType() for inp in nodedef.getInputs()}
sigOutputs = {out.getName(): out.getType() for out in nodedef.getOutputs()}
sig = NodeSignature.create(sigInputs, sigOutputs)
node.signatures.add(sig)
# Store defaults keyed by signature
sigDefaults = {}
for inp in nodedef.getInputs():
if inp.hasDefaultGeomPropString():
sigDefaults[inp.getName()] = (inp.getDefaultGeomPropString(), True)
elif inp.getValue() is not None:
sigDefaults[inp.getName()] = (inp.getValue(), False)
if sigDefaults:
defaults[(nodeName, sig)] = sigDefaults
return nodes, defaults
# -----------------------------------------------------------------------------
# Comparison Logic
# -----------------------------------------------------------------------------
def compareSignatureDefaults(nodeName, signature, specNode, libDefaults, geompropNames):
'''Compare default values for a matching signature. Returns list of Differences.'''
differences = []
for portName, valueType in signature.inputs:
specPort = specNode._specInputs.get(portName)
if not specPort or not specPort.default:
continue
specValue, specIsGeomprop = expandSpecDefaultToValue(specPort.default, valueType, geompropNames)
libValue, libIsGeomprop = libDefaults.get(portName, (None, False))
# Skip if either value is unavailable
if specValue is None or libValue is None:
continue
# Compare values (geomprops compare as strings, typed values use equality)
valuesMatch = (specValue == libValue) if (specIsGeomprop == libIsGeomprop) else False
if not valuesMatch:
differences.append(Difference(
diffType=DiffType.DEFAULT_MISMATCH,
node=nodeName,
port=portName,
signature=signature,
valueType=valueType,
specDefault=formatDefaultValue(specValue, valueType, geompropNames),
libDefault=formatDefaultValue(libValue, valueType, geompropNames),
))
return differences
def findLibraryMatch(specSig, libSigs):
'''Find a matching library signature. Returns (matchType, libSig, extraInLib, extraInSpec).'''
specInputs = set(specSig.inputs)
specOutputs = set(specSig.outputs)
for libSig in libSigs:
libInputs = set(libSig.inputs)
libOutputs = set(libSig.outputs)
# Check for exact match
if specInputs == libInputs and specOutputs == libOutputs:
return MatchType.EXACT, libSig, None, None
# Check for different input sets (same outputs, different inputs)
if specOutputs == libOutputs and specInputs != libInputs:
# If there are common inputs, verify they have the same types
commonInputNames = {name for name, _ in specInputs} & {name for name, _ in libInputs}
if commonInputNames:
specInputDict = dict(specSig.inputs)
libInputDict = dict(libSig.inputs)
typesMatch = all(specInputDict[n] == libInputDict[n] for n in commonInputNames)
if not typesMatch:
continue # Common inputs have different types - not a match
extraInLib = tuple(sorted(libInputs - specInputs))
extraInSpec = tuple(sorted(specInputs - libInputs))
return MatchType.DIFFERENT_INPUTS, libSig, extraInLib, extraInSpec
return None, None, None, None
def compareNodes(specNodes, libNodes, libDefaults, geompropNames, compareDefaults=False):
'''Compare nodes between spec and library. Returns list of Differences.'''
differences = []
specNames = set(specNodes.keys())
libNames = set(libNodes.keys())
# Nodes in spec but not in library
for nodeName in sorted(specNames - libNames):
differences.append(Difference(
diffType=DiffType.NODE_MISSING_IN_LIBRARY,
node=nodeName))
# Nodes in library but not in spec
for nodeName in sorted(libNames - specNames):
differences.append(Difference(
diffType=DiffType.NODE_MISSING_IN_SPEC,
node=nodeName))
# Compare signatures for common nodes
for nodeName in sorted(specNames & libNames):
specNode = specNodes[nodeName]
libNode = libNodes[nodeName]
specSigs = specNode.signatures
libSigs = libNode.signatures
# Track which signatures have been matched
matchedLibSigs = set()
matchedSpecSigs = set()
inputDiffMatches = [] # (specSig, libSig, extraInLib, extraInSpec)
# For each spec signature, find matching library signature
for specSig in specSigs:
matchType, libSig, extraInLib, extraInSpec = findLibraryMatch(specSig, libSigs)
if matchType == MatchType.EXACT:
matchedLibSigs.add(libSig)
matchedSpecSigs.add(specSig)
# Compare defaults for exact matches
if compareDefaults:
sigDefaults = libDefaults.get((nodeName, libSig), {})
differences.extend(compareSignatureDefaults(
nodeName, specSig, specNode, sigDefaults, geompropNames))
elif matchType == MatchType.DIFFERENT_INPUTS:
matchedLibSigs.add(libSig)
matchedSpecSigs.add(specSig)
inputDiffMatches.append((specSig, libSig, extraInLib, extraInSpec))
# Compare defaults for different input matches too (for common ports)
if compareDefaults:
sigDefaults = libDefaults.get((nodeName, libSig), {})
differences.extend(compareSignatureDefaults(
nodeName, specSig, specNode, sigDefaults, geompropNames))
# Report different input set matches
for specSig, libSig, extraInLib, extraInSpec in sorted(inputDiffMatches, key=lambda x: str(x[0])):
differences.append(Difference(
diffType=DiffType.SIGNATURE_DIFFERENT_INPUTS,
node=nodeName,
signature=specSig,
extraInLib=extraInLib,
extraInSpec=extraInSpec,
))
# Spec signatures not matched by any library signature
for specSig in sorted(specSigs - matchedSpecSigs, key=str):
differences.append(Difference(
diffType=DiffType.SIGNATURE_MISSING_IN_LIBRARY,
node=nodeName,
signature=specSig,
))
# Library signatures not matched by any spec signature
for libSig in sorted(libSigs - matchedLibSigs, key=str):
differences.append(Difference(
diffType=DiffType.SIGNATURE_MISSING_IN_SPEC,
node=nodeName,
signature=libSig,
))
return differences
# -----------------------------------------------------------------------------
# Output Formatting
# -----------------------------------------------------------------------------
def printDifferences(differences):
'''Print the differences in a formatted way.'''
if not differences:
print("No differences found between specification and data library.")
return
# Group differences by type
byType = {}
for diff in differences:
byType.setdefault(diff.diffType, []).append(diff)
print(f"\n{'=' * 70}")
print(f"COMPARISON RESULTS: {len(differences)} difference(s) found")
print(f"{'=' * 70}")
for diffType in DiffType:
if diffType not in byType:
continue
diffs = byType[diffType]
print(f"\n{diffType.value} ({len(diffs)}):")
print("-" * 50)
for diff in diffs:
for line in formatDifference(diff):
print(line)
# -----------------------------------------------------------------------------
# Main Entry Point
# -----------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(
description="Compare node definitions between a specification Markdown document and a data library MaterialX document.")
parser.add_argument('--spec', dest='specFile',
help='Path to the specification Markdown document. Defaults to documents/Specification/MaterialX.StandardNodes.md')
parser.add_argument('--mtlx', dest='mtlxFile',
help='Path to the data library MaterialX document. Defaults to libraries/stdlib/stdlib_defs.mtlx')
parser.add_argument('--defaults', dest='compareDefaults', action='store_true',
help='Compare default values for inputs using MaterialX typed value comparison')
parser.add_argument('--listNodes', dest='listNodes', action='store_true',
help='List all nodes and their node signature counts')
opts = parser.parse_args()
# Determine file paths
repoRoot = Path(__file__).resolve().parent.parent.parent
specPath = Path(opts.specFile) if opts.specFile else repoRoot / 'documents' / 'Specification' / 'MaterialX.StandardNodes.md'
mtlxPath = Path(opts.mtlxFile) if opts.mtlxFile else repoRoot / 'libraries' / 'stdlib' / 'stdlib_defs.mtlx'
# Verify files exist
if not specPath.exists():
raise FileNotFoundError(f"Specification document not found: {specPath}")
if not mtlxPath.exists():
raise FileNotFoundError(f"MTLX document not found: {mtlxPath}")
print(f"Comparing:")
print(f" Specification: {specPath}")
print(f" Data Library: {mtlxPath}")
# Load standard libraries
stdlib = loadStandardLibraries()
geompropNames = buildGeompropNames(stdlib)
# Parse specification
print("\nParsing specification...")
specNodes, invalidEntries = parseSpecDocument(specPath, stdlib, geompropNames)
specSigCount = sum(len(n.signatures) for n in specNodes.values())
print(f" Found {len(specNodes)} nodes with {specSigCount} node signatures")
if invalidEntries:
print(f" Found {len(invalidEntries)} invalid specification entries")
# Load data library
print("Loading data library...")
libNodes, libDefaults = loadDataLibrary(mtlxPath)
libSigCount = sum(len(n.signatures) for n in libNodes.values())
print(f" Found {len(libNodes)} nodes with {libSigCount} node signatures")
# List nodes if requested
if opts.listNodes:
print("\nNodes in Specification:")
for name in sorted(specNodes.keys()):
node = specNodes[name]
print(f" {name}: {len(node.signatures)} signature(s)")
print("\nNodes in Data Library:")
for name in sorted(libNodes.keys()):
node = libNodes[name]
print(f" {name}: {len(node.signatures)} signature(s)")
# Compare nodes
print("\nComparing node signatures...")
differences = compareNodes(specNodes, libNodes, libDefaults, geompropNames, opts.compareDefaults)
# Include invalid spec entries in the differences
differences = invalidEntries + differences
# Print differences
printDifferences(differences)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,268 @@
#!/usr/bin/env python
'''
Construct a MaterialX file from the textures in the given folder, using the standard data libraries
to build a mapping from texture filenames to shader inputs.
By default the standard_surface shading model is assumed, with the --shadingModel option used to
select any other shading model in the data libraries.
'''
import os
import re
import argparse
from difflib import SequenceMatcher
import MaterialX as mx
UDIM_TOKEN = '.<UDIM>.'
UDIM_REGEX = r'\.\d+\.'
TEXTURE_EXTENSIONS = [ "exr", "png", "jpg", "jpeg", "tif", "hdr" ]
INPUT_ALIASES = { "roughness": "specular_roughness" }
class UdimFilePath(mx.FilePath):
def __init__(self, pathString):
super().__init__(pathString)
self._isUdim = False
self._udimFiles = []
self._udimRegex = re.compile(UDIM_REGEX)
textureDir = self.getParentPath()
textureName = self.getBaseName()
textureExtension = self.getExtension()
if not self._udimRegex.search(textureName):
self._udimFiles = [self]
return
self._isUdim = True
fullNamePattern = self._udimRegex.sub(self._udimRegex.pattern.replace('\\', '\\\\'),
textureName)
udimFiles = filter(
lambda f: re.search(fullNamePattern, f.asString()),
textureDir.getFilesInDirectory(textureExtension)
)
self._udimFiles = [textureDir / f for f in udimFiles]
def __str__(self):
return self.asPattern()
def asPattern(self):
if not self._isUdim:
return self.asString()
textureDir = self.getParentPath()
textureName = self.getBaseName()
pattern = textureDir / mx.FilePath(
self._udimRegex.sub(UDIM_TOKEN, textureName))
return pattern.asString()
def isUdim(self):
return self._isUdim
def getUdimFiles(self):
return self._udimFiles
def getUdimNumbers(self):
def _extractUdimNumber(_file):
pattern = self._udimRegex.search(_file.getBaseName())
if pattern:
return re.search(r"\d+", pattern.group()).group()
return list(map(_extractUdimNumber, self._udimFiles))
def getNameWithoutExtension(self):
if self._isUdim:
name = self._udimRegex.split(self.getBaseName())[0]
else:
name = self.getBaseName().rsplit('.', 1)[0]
return re.sub(r'[^\w\s]+', '_', name)
def listTextures(textureDir, texturePrefix=None):
'''
Return a list of texture filenames matching known extensions.
'''
texturePrefix = texturePrefix or ""
allTextures = []
for ext in TEXTURE_EXTENSIONS:
textures = [textureDir / f for f in textureDir.getFilesInDirectory(ext)
if f.asString().lower().startswith(texturePrefix.lower())]
while textures:
textureFile = UdimFilePath(textures[0].asString())
allTextures.append(textureFile)
for udimFile in textureFile.getUdimFiles():
textures.remove(udimFile)
return allTextures
def findBestMatch(textureName, shadingModel):
'''
Given a texture name and shading model, return the shader input that is the closest match.
'''
parts = textureName.rsplit("_")
baseTexName = parts[-1]
if baseTexName.lower() == 'color':
baseTexName = ''.join(parts[-2:])
if baseTexName in INPUT_ALIASES:
baseTexName = INPUT_ALIASES.get(baseTexName.lower())
shaderInputs = shadingModel.getActiveInputs()
ratios = []
for shaderInput in shaderInputs:
inputName = shaderInput.getName()
inputName = re.sub(r'[^a-zA-Z0-9\s]', '', inputName).lower()
baseTexName = re.sub(r'[^a-zA-Z0-9\s]', '', baseTexName).lower()
sequenceScore = SequenceMatcher(None, inputName, baseTexName).ratio()
ratios.append(sequenceScore * 100)
highscore = max(ratios)
if highscore < 50:
return None
idx = ratios.index(highscore)
return shaderInputs[idx]
def buildDocument(textureFiles, mtlxFile, shadingModel, colorspace, useTiledImage):
'''
Build a MaterialX document from the given textures and shading model.
'''
# Find the default library nodedef, if any, for the requested shading model.
stdlib = mx.createDocument()
mx.loadLibraries(mx.getDefaultDataLibraryFolders(), mx.getDefaultDataSearchPath(), stdlib)
matchingNodeDefs = stdlib.getMatchingNodeDefs(shadingModel)
if not matchingNodeDefs:
print('Shading model', shadingModel, 'not found in the MaterialX data libraries')
return None
shadingModelNodeDef = matchingNodeDefs[0]
for nodeDef in matchingNodeDefs:
if nodeDef.getAttribute('isdefaultversion') == 'true':
shadingModelNodeDef = nodeDef
# Create content document.
doc = mx.createDocument()
materialName = mx.createValidName(mtlxFile.getBaseName().rsplit('.', 1)[0])
nodeGraph = doc.addNodeGraph('NG_' + materialName)
shaderNode = doc.addNode(shadingModel, 'SR_' + materialName, 'surfaceshader')
doc.addMaterialNode('M_' + materialName, shaderNode)
# Iterate over texture files.
imageNodeCategory = 'tiledimage' if useTiledImage else 'image'
udimNumbers = set()
for textureFile in textureFiles:
textureName = textureFile.getNameWithoutExtension()
shaderInput = findBestMatch(textureName, shadingModelNodeDef)
if not shaderInput:
print('Skipping', textureFile.getBaseName(), 'which does not match any', shadingModel, 'input')
continue
inputName = shaderInput.getName()
inputType = shaderInput.getType()
# Skip inputs that have already been created, e.g. in multi-UDIM materials.
if shaderNode.getInput(inputName) or nodeGraph.getChild(textureName):
continue
mtlInput = shaderNode.addInput(inputName)
textureName = nodeGraph.createValidChildName(textureName)
imageNode = nodeGraph.addNode(imageNodeCategory, textureName, inputType)
# Set color space.
if shaderInput.isColorType():
imageNode.setColorSpace(colorspace)
# Set file path.
filePathString = os.path.relpath(textureFile.asPattern(), mtlxFile.getParentPath().asString())
imageNode.setInputValue('file', filePathString, 'filename')
# Apply special cases for normal maps.
inputNode = imageNode
connNode = imageNode
inBetweenNodes = []
if inputName.endswith('normal') and shadingModel == 'standard_surface':
inBetweenNodes = ["normalmap"]
for inNodeName in inBetweenNodes:
connNode = nodeGraph.addNode(inNodeName, textureName + '_' + inNodeName, inputType)
connNode.setConnectedNode('in', inputNode)
inputNode = connNode
# Create output.
outputNode = nodeGraph.addOutput(textureName + '_output', inputType)
outputNode.setConnectedNode(connNode)
mtlInput.setConnectedOutput(outputNode)
mtlInput.setType(inputType)
if textureFile.isUdim():
udimNumbers.update(set(textureFile.getUdimNumbers()))
# Create udim set
if udimNumbers:
geomInfoName = doc.createValidChildName('GI_' + materialName)
geomInfo = doc.addGeomInfo(geomInfoName)
geomInfo.setGeomPropValue('udimset', list(udimNumbers), "stringarray")
# Return the new document
return doc
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--outputFilename', dest='outputFilename', type=str, help='Filename of the output MaterialX document.')
parser.add_argument('--shadingModel', dest='shadingModel', type=str, default="standard_surface", help='The shading model used in analyzing input textures.')
parser.add_argument('--colorSpace', dest='colorSpace', type=str, help='The colorspace in which input textures should be interpreted, defaulting to srgb_texture.')
parser.add_argument('--texturePrefix', dest='texturePrefix', type=str, help='Filter input textures by the given prefix.')
parser.add_argument('--tiledImage', dest='tiledImage', action="store_true", help='Request tiledimage nodes instead of image nodes.')
parser.add_argument(dest='inputDirectory', nargs='?', help='Input folder that will be scanned for textures, defaulting to the current working directory.')
options = parser.parse_args()
texturePath = mx.FilePath.getCurrentPath()
if options.inputDirectory:
texturePath = mx.FilePath(options.inputDirectory)
if not texturePath.isDirectory():
print('Input folder not found:', texturePath)
return
mtlxFile = texturePath / mx.FilePath('material.mtlx')
if options.outputFilename:
mtlxFile = mx.FilePath(options.outputFilename)
textureFiles = listTextures(texturePath, texturePrefix=options.texturePrefix)
if not textureFiles:
print('No matching textures found in input folder.')
return
# Get shading model and color space.
shadingModel = 'standard_surface'
colorspace = 'srgb_texture'
if options.shadingModel:
shadingModel = options.shadingModel
if options.colorSpace:
colorspace = options.colorSpace
print('Analyzing textures in the', texturePath.asString(), 'folder for the', shadingModel, 'shading model.')
# Create the MaterialX document.
doc = buildDocument(textureFiles, mtlxFile, shadingModel, colorspace, options.tiledImage)
if not doc:
return
if options.outputFilename:
# Write the document to disk.
if not mtlxFile.getParentPath().exists():
mtlxFile.getParentPath().createDirectory()
mx.writeToXmlFile(doc, mtlxFile.asString())
print('Wrote MaterialX document to disk:', mtlxFile.asString())
else:
# Print the document to the standard output.
print('Generated MaterialX document:')
print(mx.writeToXmlString(doc))
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python
'''
Generate shader code for each renderable element in a MaterialX document or folder.
The currently supported target languages are GLSL, ESSL, MSL, OSL, and MDL.
'''
import sys, os, argparse, subprocess
import MaterialX as mx
import MaterialX.PyMaterialXGenGlsl as mx_gen_glsl
import MaterialX.PyMaterialXGenMdl as mx_gen_mdl
import MaterialX.PyMaterialXGenMsl as mx_gen_msl
import MaterialX.PyMaterialXGenOsl as mx_gen_osl
import MaterialX.PyMaterialXGenSlang as mx_gen_slang
import MaterialX.PyMaterialXGenShader as mx_gen_shader
def validateCode(sourceCodeFile, codevalidator, codevalidatorArgs):
if codevalidator:
cmd = codevalidator.split()
cmd.append(sourceCodeFile)
if codevalidatorArgs:
cmd.append(codevalidatorArgs)
cmd_flatten ='----- Run Validator: '
for c in cmd:
cmd_flatten += c + ' '
print(cmd_flatten)
try:
output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
return output.decode(encoding='utf-8')
except subprocess.CalledProcessError as out:
return (out.output.decode(encoding='utf-8'))
return ""
def getMaterialXFiles(rootPath):
filelist = []
if os.path.isdir(rootPath):
for subdir, dirs, files in os.walk(rootPath):
for file in files:
if file.endswith('mtlx'):
filelist.append(os.path.join(subdir, file))
else:
filelist.append( rootPath )
return filelist
def main():
parser = argparse.ArgumentParser(description='Generate shader code for each renderable element in a MaterialX document or folder.')
parser.add_argument('--path', dest='paths', action='append', nargs='+', help='An additional absolute search path location (e.g. "/projects/MaterialX")')
parser.add_argument('--library', dest='libraries', action='append', nargs='+', help='An additional relative path to a custom data library folder (e.g. "libraries/custom")')
parser.add_argument('--target', dest='target', default='glsl', help='Target shader generator to use (e.g. "glsl, osl, mdl, essl, vulkan, wgsl"). Default is glsl.')
parser.add_argument('--outputPath', dest='outputPath', help='File path to output shaders to. If not specified, is the location of the input document is used.')
parser.add_argument('--validator', dest='validator', nargs='?', const=' ', type=str, help='Name of executable to perform source code validation.')
parser.add_argument('--validatorArgs', dest='validatorArgs', nargs='?', const=' ', type=str, help='Optional arguments for code validator.')
parser.add_argument('--vulkanGlsl', dest='vulkanCompliantGlsl', default=False, type=bool, help='Set to True to generate Vulkan-compliant GLSL when using the genglsl target.')
parser.add_argument('--shaderInterfaceType', dest='shaderInterfaceType', default=0, type=int, help='Set the type of shader interface to be generated')
parser.add_argument(dest='inputFilename', help='Path to input document or folder containing input documents.')
opts = parser.parse_args()
# Load standard and custom data libraries.
stdlib = mx.createDocument()
searchPath = mx.getDefaultDataSearchPath()
libraryFolders = []
if opts.paths:
for pathList in opts.paths:
for path in pathList:
searchPath.append(path)
if opts.libraries:
for libraryList in opts.libraries:
for library in libraryList:
libraryFolders.append(library)
libraryFolders.extend(mx.getDefaultDataLibraryFolders())
mx.loadLibraries(libraryFolders, searchPath, stdlib)
# Generate shaders for each input document.
for inputFilename in getMaterialXFiles(opts.inputFilename):
doc = mx.createDocument()
try:
mx.readFromXmlFile(doc, inputFilename)
doc.setDataLibrary(stdlib)
except mx.ExceptionFileMissing as err:
print('Generation failed: "', err, '"')
sys.exit(-1)
print('---------- Generate code for file: ', inputFilename, '--------------------')
valid, msg = doc.validate()
if not valid:
print('Validation warnings for input document:')
print(msg)
gentarget = 'glsl'
if opts.target:
gentarget = opts.target
if gentarget == 'osl':
shadergen = mx_gen_osl.OslShaderGenerator.create()
elif gentarget == 'mdl':
shadergen = mx_gen_mdl.MdlShaderGenerator.create()
elif gentarget == 'essl':
shadergen = mx_gen_glsl.EsslShaderGenerator.create()
elif gentarget == 'vulkan':
shadergen = mx_gen_glsl.VkShaderGenerator.create()
elif gentarget == 'wgsl':
shadergen = mx_gen_glsl.WgslShaderGenerator.create()
elif gentarget == 'msl':
shadergen = mx_gen_msl.MslShaderGenerator.create()
elif gentarget == 'slang':
shadergen = mx_gen_slang.SlangShaderGenerator.create()
else:
shadergen = mx_gen_glsl.GlslShaderGenerator.create()
codeSearchPath = searchPath
codeSearchPath.append(os.path.dirname(inputFilename))
context = mx_gen_shader.GenContext(shadergen)
context.registerSourceCodeSearchPath(codeSearchPath)
shadergen.registerTypeDefs(doc);
# If we're generating Vulkan-compliant GLSL then set the binding context
if opts.vulkanCompliantGlsl:
bindingContext = mx_gen_glsl.GlslResourceBindingContext.create(0,0)
context.pushUserData('udbinding', bindingContext)
genoptions = context.getOptions()
if opts.shaderInterfaceType == 0 or opts.shaderInterfaceType == 1:
genoptions.shaderInterfaceType = mx_gen_shader.ShaderInterfaceType(opts.shaderInterfaceType)
else:
genoptions.shaderInterfaceType = mx_gen_shader.ShaderInterfaceType.SHADER_INTERFACE_COMPLETE
print('- Set up CMS ...')
cms = mx_gen_shader.DefaultColorManagementSystem.create(shadergen.getTarget())
cms.loadLibrary(doc)
shadergen.setColorManagementSystem(cms)
print('- Set up Units ...')
unitsystem = mx_gen_shader.UnitSystem.create(shadergen.getTarget())
registry = mx.UnitConverterRegistry.create()
distanceTypeDef = doc.getUnitTypeDef('distance')
registry.addUnitConverter(distanceTypeDef, mx.LinearUnitConverter.create(distanceTypeDef))
angleTypeDef = doc.getUnitTypeDef('angle')
registry.addUnitConverter(angleTypeDef, mx.LinearUnitConverter.create(angleTypeDef))
unitsystem.loadLibrary(stdlib)
unitsystem.setUnitConverterRegistry(registry)
shadergen.setUnitSystem(unitsystem)
genoptions.targetDistanceUnit = 'meter'
pathPrefix = ''
if opts.outputPath and os.path.exists(opts.outputPath):
pathPrefix = opts.outputPath + os.path.sep
else:
pathPrefix = os.path.dirname(os.path.abspath(inputFilename))
print('- Shader output path: ' + pathPrefix)
failedShaders = ""
for elem in mx_gen_shader.findRenderableElements(doc):
elemName = elem.getName()
print('-- Generate code for element: ' + elemName)
elemName = mx.createValidName(elemName)
shader = shadergen.generate(elemName, elem, context)
if shader:
# Use extension of .vert and .frag as it's type is
# recognized by glslangValidator
if gentarget in ['glsl', 'essl', 'vulkan', 'msl', 'wgsl']:
pixelSource = shader.getSourceCode(mx_gen_shader.PIXEL_STAGE)
filename = pathPrefix + "/" + shader.getName() + "." + gentarget + ".frag"
print('--- Wrote pixel shader to: ' + filename)
file = open(filename, 'w+')
file.write(pixelSource)
file.close()
errors = validateCode(filename, opts.validator, opts.validatorArgs)
vertexSource = shader.getSourceCode(mx_gen_shader.VERTEX_STAGE)
filename = pathPrefix + "/" + shader.getName() + "." + gentarget + ".vert"
print('--- Wrote vertex shader to: ' + filename)
file = open(filename, 'w+')
file.write(vertexSource)
file.close()
errors += validateCode(filename, opts.validator, opts.validatorArgs)
else:
pixelSource = shader.getSourceCode(mx_gen_shader.PIXEL_STAGE)
filename = pathPrefix + "/" + shader.getName() + "." + gentarget
print('--- Wrote pixel shader to: ' + filename)
file = open(filename, 'w+')
file.write(pixelSource)
file.close()
errors = validateCode(filename, opts.validator, opts.validatorArgs)
if errors != "":
print("--- Validation failed for element: ", elemName)
print("----------------------------")
print('--- Error log: ', errors)
print("----------------------------")
failedShaders += (elemName + ' ')
else:
print("--- Validation passed for element:", elemName)
else:
print("--- Validation failed for element:", elemName)
failedShaders += (elemName + ' ')
if failedShaders != "":
sys.exit(-1)
if __name__ == '__main__':
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python
'''
Print markdown documentation for each nodedef in the given document.
'''
import argparse
import sys
import MaterialX as mx
HEADERS = ('Name', 'Type', 'Default Value',
'UI name', 'UI min', 'UI max', 'UI Soft Min', 'UI Soft Max', 'UI step', 'UI group', 'UI Advanced', 'Doc', 'Uniform')
ATTR_NAMES = ('uiname', 'uimin', 'uimax', 'uisoftmin', 'uisoftmax', 'uistep', 'uifolder', 'uiadvanced', 'doc', 'uniform' )
def main():
parser = argparse.ArgumentParser(description="Print documentation for each nodedef in the given document.")
parser.add_argument(dest="inputFilename", help="Filename of the input MaterialX document.")
parser.add_argument('--docType', dest='documentType', default='md', help='Document type. Default is "md" (Markdown). Specify "html" for HTML output')
parser.add_argument('--showInherited', default=False, action='store_true', help='Show inherited inputs. Default is False')
opts = parser.parse_args()
doc = mx.createDocument()
try:
mx.readFromXmlFile(doc, opts.inputFilename)
except mx.ExceptionFileMissing as err:
print(err)
sys.exit(0)
for nd in doc.getNodeDefs():
# HTML output
if opts.documentType == "html":
print('<head><style>')
print('table, th, td {')
print(' border-bottom: 1px solid; border-collapse: collapse; padding: 10px;')
print('}')
print('</style></head>')
print('<ul>')
print('<li> <em>Nodedef</em>: %s' % nd.getName())
print('<li> <em>Type</em>: %s' % nd.getType())
if len(nd.getNodeGroup()) > 0:
print('<li> <em>Node Group</em>: %s' % nd.getNodeGroup())
if len(nd.getVersionString()) > 0:
print('<li> <em>Version</em>: %s. Is default: %s' % (nd.getVersionString(), nd.getDefaultVersion()))
if len(nd.getInheritString()) > 0:
print('<li> <em>Inherits From</em>: %s' % nd.getInheritString())
print('<li> <em>Doc</em>: %s\n' % nd.getAttribute('doc'))
print('</ul>')
print('<table><tr>')
for h in HEADERS:
print('<th>' + h + '</th>')
print('</tr>')
inputList = nd.getActiveInputs() if opts.showInherited else nd.getInputs()
tokenList = nd.getActiveTokens() if opts.showInherited else nd.getTokens()
outputList = nd.getActiveOutputs() if opts.showInherited else nd.getOutputs()
totalList = inputList + tokenList + outputList;
for port in totalList:
print('<tr>')
infos = []
if port in outputList:
infos.append('<em>'+ port.getName() + '</em>')
elif port in tokenList:
infos.append(port.getName())
else:
infos.append('<b>'+ port.getName() + '</b>')
infos.append(port.getType())
val = port.getValue()
if port.getType() == "float":
val = round(val, 6)
infos.append(str(val))
for attrname in ATTR_NAMES:
infos.append(port.getAttribute(attrname))
for info in infos:
print('<td>' + info + '</td>')
print('</tr>')
print('</table>')
# Markdown output
else:
print('- *Nodedef*: %s' % nd.getName())
print('- *Type*: %s' % nd.getType())
if len(nd.getNodeGroup()) > 0:
print('- *Node Group*: %s' % nd.getNodeGroup())
if len(nd.getVersionString()) > 0:
print('- *Version*: %s. Is default: %s' % (nd.getVersionString(), nd.getDefaultVersion()))
if len(nd.getInheritString()) > 0:
print('- *Inherits From*: %s' % nd.getInheritString())
print('- *Doc*: %s\n' % nd.getAttribute('doc'))
print('| ' + ' | '.join(HEADERS) + ' |')
print('|' + ' ---- |' * len(HEADERS) + '')
inputList = nd.getActiveInputs() if opts.showInherited else nd.getInputs()
tokenList = nd.getActiveTokens() if opts.showInherited else nd.getTokens()
outputList = nd.getActiveOutputs() if opts.showInherited else nd.getOutputs()
totalList = inputList + tokenList + outputList;
for port in totalList:
infos = []
if port in outputList:
infos.append('*'+ port.getName() + '*')
elif port in tokenList:
infos.append(port.getName())
else:
infos.append('**'+ port.getName() + '**')
infos.append(port.getType())
val = port.getValue()
if port.getType() == "float":
val = round(val, 6)
infos.append(str(val))
for attrname in ATTR_NAMES:
infos.append(port.getAttribute(attrname))
print('| ' + " | ".join(infos) + ' |')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python
'''
Reformat a folder of MaterialX documents in place, optionally upgrading
the documents to the latest version of the standard.
'''
import argparse
import os
import xml.etree.ElementTree as ET
import MaterialX as mx
def is_well_formed(xml_string):
error = ''
try:
ET.fromstring(xml_string)
except ET.ParseError as e:
error = str(e)
def main():
parser = argparse.ArgumentParser(description="Reformat a folder of MaterialX documents in place.")
parser.add_argument('-y', '--yes', dest='yes', action="store_true", help="Proceed without asking for confirmation from the user.")
parser.add_argument('-u', '--upgrade', dest='upgrade', action="store_true", help='Upgrade documents to the latest version of the standard.')
parser.add_argument('-v', '--validate', dest='validate', action="store_true", help='Perform MaterialX validation on documents after reformatting.')
parser.add_argument('-x', '--xml_syntax', dest='xml_syntax', action="store_true", help='Check XML syntax after reformatting.')
parser.add_argument(dest="inputFolder", help="An input folder to scan for MaterialX documents.")
opts = parser.parse_args()
validDocs = dict()
for root, dirs, files in os.walk(opts.inputFolder):
for filename in files:
fullpath = os.path.join(root, filename)
if fullpath.endswith('.mtlx'):
doc = mx.createDocument()
try:
readOptions = mx.XmlReadOptions()
readOptions.readComments = True
readOptions.readNewlines = True
readOptions.upgradeVersion = opts.upgrade
try:
mx.readFromXmlFile(doc, fullpath, mx.FileSearchPath(), readOptions)
except Exception as err:
print('Skipping "' + filename + '" due to exception: ' + str(err))
continue
validDocs[fullpath] = doc
except mx.Exception:
pass
if not validDocs:
print('No MaterialX documents were found in "%s"' % (opts.inputFolder))
return
print('Found %s MaterialX files in "%s"' % (len(validDocs), opts.inputFolder))
mxVersion = mx.getVersionIntegers()
if not opts.yes:
if opts.upgrade:
question = 'Would you like to upgrade all %i documents to MaterialX v%i.%i in place (y/n)?' % (len(validDocs), mxVersion[0], mxVersion[1])
else:
question = 'Would you like to reformat all %i documents in place (y/n)?' % len(validDocs)
answer = input(question)
if answer != 'y' and answer != 'Y':
return
validate = opts.validate
if validate:
print(f'- Validate documents')
xml_syntax = opts.xml_syntax
if xml_syntax:
print(f'- Check XML syntax')
for (filename, doc) in validDocs.items():
if xml_syntax:
xml_string = mx.writeToXmlString(doc)
errors = is_well_formed(xml_string)
if errors:
print(f'- Warning: Document {filename} is not well-formed XML: {errors}')
if validate:
is_valid, errors = doc.validate()
if not is_valid:
print(f'- Warning: Document {filename} is invalid. Errors {errors}.')
mx.writeToXmlFile(doc, filename)
if opts.upgrade:
print('Upgraded %i documents to MaterialX v%i.%i' % (len(validDocs), mxVersion[0], mxVersion[1]))
else:
print('Reformatted %i documents ' % len(validDocs))
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,363 @@
#!/usr/bin/env python
'''
Verify that the given file is a valid MaterialX document.
'''
import argparse
import sys
import MaterialX as mx
def main():
parser = argparse.ArgumentParser(description="Verify that the given file is a valid MaterialX document.")
parser.add_argument("--resolve", dest="resolve", action="store_true", help="Resolve inheritance and string substitutions.")
parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print summary of elements found in the document.")
parser.add_argument("--stdlib", dest="stdlib", action="store_true", help="Import standard MaterialX libraries into the document.")
parser.add_argument(dest="inputFilename", help="Filename of the input document.")
opts = parser.parse_args()
# Load standard libraries if requested.
stdlib = None
if opts.stdlib:
stdlib = mx.createDocument()
try:
mx.loadLibraries(mx.getDefaultDataLibraryFolders(), mx.getDefaultDataSearchPath(), stdlib)
except Exception as err:
print(err)
sys.exit(0)
# Read and validate the source document.
doc = mx.createDocument()
try:
mx.readFromXmlFile(doc, opts.inputFilename)
if stdlib:
doc.setDataLibrary(stdlib)
except mx.ExceptionFileMissing as err:
print(err)
sys.exit(0)
valid, message = doc.validate()
if (valid):
print("%s is a valid MaterialX document in v%s" % (opts.inputFilename, mx.getVersionString()))
else:
print("%s is not a valid MaterialX document in v%s" % (opts.inputFilename, mx.getVersionString()))
print(message)
# Generate verbose output if requested.
if opts.verbose:
nodegraphs = doc.getNodeGraphs()
materials = doc.getMaterialNodes()
looks = doc.getLooks()
lookgroups = doc.getLookGroups()
collections = doc.getCollections()
nodedefs = doc.getNodeDefs()
implementations = doc.getImplementations()
geominfos = doc.getGeomInfos()
geompropdefs = doc.getGeomPropDefs()
typedefs = doc.getTypeDefs()
propsets = doc.getPropertySets()
variantsets = doc.getVariantSets()
backdrops = doc.getBackdrops()
print("----------------------------------")
print("Document Version: {}.{:02d}".format(*doc.getVersionIntegers()))
print("%4d Custom Type%s%s" % (len(typedefs), pl(typedefs), listContents(typedefs, opts.resolve)))
print("%4d Custom GeomProp%s%s" % (len(geompropdefs), pl(geompropdefs), listContents(geompropdefs, opts.resolve)))
print("%4d NodeDef%s%s" % (len(nodedefs), pl(nodedefs), listContents(nodedefs, opts.resolve)))
print("%4d Implementation%s%s" % (len(implementations), pl(implementations), listContents(implementations, opts.resolve)))
print("%4d Nodegraph%s%s" % (len(nodegraphs), pl(nodegraphs), listContents(nodegraphs, opts.resolve)))
print("%4d VariantSet%s%s" % (len(variantsets), pl(variantsets), listContents(variantsets, opts.resolve)))
print("%4d Material%s%s" % (len(materials), pl(materials), listContents(materials, opts.resolve)))
print("%4d Collection%s%s" % (len(collections), pl(collections), listContents(collections, opts.resolve)))
print("%4d GeomInfo%s%s" % (len(geominfos), pl(geominfos), listContents(geominfos, opts.resolve)))
print("%4d PropertySet%s%s" % (len(propsets), pl(propsets), listContents(propsets, opts.resolve)))
print("%4d Look%s%s" % (len(looks), pl(looks), listContents(looks, opts.resolve)))
print("%4d LookGroup%s%s" % (len(lookgroups), pl(lookgroups), listContents(lookgroups, opts.resolve)))
print("%4d Top-level backdrop%s%s" % (len(backdrops), pl(backdrops), listContents(backdrops, opts.resolve)))
print("----------------------------------")
def listContents(elemlist, resolve):
if len(elemlist) == 0:
return ''
names = []
for elem in elemlist:
if elem.isA(mx.NodeDef):
outtype = elem.getType()
outs = ""
if outtype == "multioutput":
for ot in elem.getOutputs():
outs = outs + \
'\n\t %s output "%s"' % (ot.getType(), ot.getName())
names.append('%s %s "%s"%s' %
(outtype, elem.getNodeString(), elem.getName(), outs))
names.append(listNodedefInterface(elem))
elif elem.isA(mx.Implementation):
impl = elem.getName()
targs = []
if elem.hasTarget():
targs.append("target %s" % elem.getTarget())
if targs:
impl = "%s (%s)" % (impl, ", ".join(targs))
if elem.hasFunction():
if elem.hasFile():
impl = "%s [%s:%s()]" % (
impl, elem.getFile(), elem.getFunction())
else:
impl = "%s [function %s()]" % (impl, elem.getFunction())
elif elem.hasFile():
impl = "%s [%s]" % (impl, elem.getFile())
names.append(impl)
elif elem.isA(mx.Backdrop):
names.append('%s: contains "%s"' %
(elem.getName(), elem.getContainsString()))
elif elem.isA(mx.NodeGraph):
nchildnodes = len(elem.getChildren()) - elem.getOutputCount()
backdrops = elem.getBackdrops()
nbackdrops = len(backdrops)
outs = ""
if nbackdrops > 0:
for bd in backdrops:
outs = outs + '\n\t backdrop "%s"' % (bd.getName())
outs = outs + ' contains "%s"' % bd.getContainsString()
if elem.getOutputCount() > 0:
for ot in elem.getOutputs():
outs = outs + '\n\t %s output "%s"' % (ot.getType(), ot.getName())
outs = outs + traverseInputs(ot, "", 0)
nd = elem.getNodeDef()
if nd:
names.append('%s (implementation for nodedef "%s"): %d nodes%s' % (
elem.getName(), nd.getName(), nchildnodes, outs))
else:
names.append("%s: %d nodes, %d backdrop%s%s" % (
elem.getName(), nchildnodes, nbackdrops, pl(backdrops), outs))
elif elem.isA(mx.Node, mx.SURFACE_MATERIAL_NODE_STRING):
shaders = mx.getShaderNodes(elem)
names.append("%s: %d connected shader node%s" % (elem.getName(), len(shaders), pl(shaders)))
for shader in shaders:
names.append('Shader node "%s" (%s), with bindings:%s' % (shader.getName(), shader.getCategory(), listShaderBindings(shader)))
elif elem.isA(mx.GeomInfo):
props = elem.getGeomProps()
if props:
propnames = " (Geomprops: " + ", ".join(map(
lambda x: "%s=%s" % (x.getName(), getConvertedValue(x)), props)) + ")"
else:
propnames = ""
tokens = elem.getTokens()
if tokens:
tokennames = " (Tokens: " + ", ".join(map(
lambda x: "%s=%s" % (x.getName(), x.getValueString()), tokens)) + ")"
else:
tokennames = ""
names.append("%s%s%s" % (elem.getName(), propnames, tokennames))
elif elem.isA(mx.VariantSet):
vars = elem.getVariants()
if vars:
varnames = " (variants " + ", ".join(map(
lambda x: '"' + x.getName()+'"', vars)) + ")"
else:
varnames = ""
names.append("%s%s" % (elem.getName(), varnames))
elif elem.isA(mx.PropertySet):
props = elem.getProperties()
if props:
propnames = " (" + ", ".join(map(
lambda x: "%s %s%s" % (x.getType(), x.getName(), getTarget(x)), props)) + ")"
else:
propnames = ""
names.append("%s%s" % (elem.getName(), propnames))
elif elem.isA(mx.LookGroup):
lks = elem.getLooks()
if lks:
names.append("%s (looks: %s)" % (elem.getName(), lks))
else:
names.append("%s (no looks)" % (elem.getName()))
elif elem.isA(mx.Look):
mas = ""
if resolve:
mtlassns = elem.getActiveMaterialAssigns()
else:
mtlassns = elem.getMaterialAssigns()
for mtlassn in mtlassns:
mas = mas + "\n\t MaterialAssign %s to%s" % (
mtlassn.getMaterial(), getGeoms(mtlassn, resolve))
pas = ""
if resolve:
propassns = elem.getActivePropertyAssigns()
else:
propassns = elem.getPropertyAssigns()
for propassn in propassns:
propertyname = propassn.getAttribute("property")
pas = pas + "\n\t PropertyAssign %s %s to%s" % (
propassn.getType(), propertyname, getGeoms(propassn, resolve))
psas = ""
if resolve:
propsetassns = elem.getActivePropertySetAssigns()
else:
propsetassns = elem.getPropertySetAssigns()
for propsetassn in propsetassns:
propertysetname = propsetassn.getAttribute("propertyset")
psas = psas + "\n\t PropertySetAssign %s to%s" % (
propertysetname, getGeoms(propsetassn, resolve))
varas = ""
if resolve:
variantassns = elem.getActiveVariantAssigns()
else:
variantassns = elem.getVariantAssigns()
for varassn in variantassns:
varas = varas + "\n\t VariantAssign %s from variantset %s" % (
varassn.getVariantString(), varassn.getVariantSetString())
visas = ""
if resolve:
visassns = elem.getActiveVisibilities()
else:
visassns = elem.getVisibilities()
for vis in visassns:
visstr = 'on' if vis.getVisible() else 'off'
visas = visas + "\n\t Set %s visibility%s %s to%s" % (
vis.getVisibilityType(), getViewerGeoms(vis), visstr, getGeoms(vis, resolve))
names.append("%s%s%s%s%s%s" %
(elem.getName(), mas, pas, psas, varas, visas))
else:
names.append(elem.getName())
return ":\n\t" + "\n\t".join(names)
def listShaderBindings(shader):
s = ''
for inp in shader.getInputs():
bname = inp.getName()
btype = inp.getType()
if inp.hasOutputString():
outname = inp.getOutputString()
if inp.hasNodeGraphString():
ngname = inp.getNodeGraphString()
s = s + '\n\t %s "%s" -> nodegraph "%s" output "%s"' % (btype, bname, ngname, outname)
else:
s = s + '\n\t %s "%s" -> output "%s"' % (btype, bname, outname)
else:
bval = getConvertedValue(inp)
s = s + '\n\t %s "%s" = %s' % (btype, bname, bval)
return s
def listNodedefInterface(nodedef):
s = ''
for inp in nodedef.getActiveInputs():
iname = inp.getName()
itype = inp.getType()
if s:
s = s + '\n\t'
s = s + ' %s input "%s"' % (itype, iname)
for tok in nodedef.getActiveTokens():
tname = tok.getName()
ttype = tok.getType()
if s:
s = s + '\n\t'
s = s + ' %s token "%s"' % (ttype, tname)
return s
def traverseInputs(node, port, depth):
s = ''
if node.isA(mx.Output):
parent = node.getConnectedNode()
s = s + traverseInputs(parent, "", depth+1)
else:
s = s + '%s%s -> %s %s "%s"' % (spc(depth), port,
node.getType(), node.getCategory(), node.getName())
ins = node.getActiveInputs()
for i in ins:
if i.hasInterfaceName():
intname = i.getInterfaceName()
s = s + \
'%s%s ^- %s interface "%s"' % (spc(depth+1),
i.getName(), i.getType(), intname)
elif i.hasValueString():
val = getConvertedValue(i)
s = s + \
'%s%s = %s value %s' % (
spc(depth+1), i.getName(), i.getType(), val)
else:
parent = i.getConnectedNode()
if parent:
s = s + traverseInputs(parent, i.getName(), depth+1)
toks = node.getActiveTokens()
for i in toks:
if i.hasInterfaceName():
intname = i.getInterfaceName()
s = s + \
'%s[T]%s ^- %s interface "%s"' % (
spc(depth+1), i.getName(), i.getType(), intname)
elif i.hasValueString():
val = i.getValueString()
s = s + \
'%s[T]%s = %s value "%s"' % (
spc(depth+1), i.getName(), i.getType(), val)
else:
s = s + \
'%s[T]%s error: no valueString' % (
spc(depth+1), i.getName())
return s
def pl(elem):
if len(elem) == 1:
return ""
else:
return "s"
def spc(depth):
return "\n\t " + ": "*depth
# Return a value string for the element, converting units if appropriate
def getConvertedValue(elem):
if elem.getType() in ["float", "vector2", "vector3", "vector4"]:
if elem.hasUnit():
u = elem.getUnit()
print ("[Unit for %s is %s]" % (elem.getName(), u))
if elem.hasUnitType():
utype = elem.getUnitType()
print ("[Unittype for %s is %s]" % (elem.getName(), utype))
# NOTDONE...
return elem.getValueString()
def getGeoms(elem, resolve):
s = ""
if elem.hasGeom():
if resolve:
s = s + ' geom "%s"' % elem.getActiveGeom()
else:
s = s + ' geom "%s"' % elem.getGeom()
if elem.hasCollectionString():
s = s + ' collection "%s"' % elem.getCollectionString()
return s
def getViewerGeoms(elem):
s = ""
if elem.hasViewerGeom():
s = s + ' viewergeom "%s"' % elem.getViewerGeom()
if elem.hasViewerCollection():
s = s + ' viewercollection "%s"' % elem.getViewerCollection()
if s:
s = " of" + s
return s
def getTarget(elem):
if elem.hasTarget():
return ' [target "%s"]' % elem.getTarget()
else:
return ""
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,429 @@
"""
Sample MaterialX ImageLoader implementation using OpenImageIO package.
This module provides a MaterialX-compatible ImageLoader implementation using OpenImageIO (OIIO).
The test will test loading an image, save it out, and optionally previewing it.
Steps:
1. Create an OIIOLoader which is derived from the ImageLoader interface class.
2. Create a new ImageHandler and register the loader with it.
3. Request to acquire an image using the ImageHandler. An EXR image is requested.
4. OIIOLoader will return supported extensions and match the requested image format.
5. As such the OIIOLoader will be requested to load in the EXR image, convert the
data and return a MaterialX Image.
6. Try to acquire the image again. This should returnt the cached MaterialX Image.
7. Save the image back to disk in the original format.
The image can optionally be previewed after load before save.
- Python Dependencies:
- OpenImageIO (version 3.0.6.1)
- API Docs can be found here: https://openimageio.readthedocs.io/en/v3.0.6.1/)
- numpy : For numerical operations on image data
- matplotlib : If image preview is desired.
"""
import ctypes
import os
import argparse
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("OIIOLoad")
try:
import MaterialX as mx
import MaterialX.PyMaterialXRender as mx_render
except ImportError:
logger.error("Required modules not found. Please install MaterialX.")
raise
try:
import OpenImageIO as oiio
import numpy as np
except ImportError:
logger.error("Required modules not found. Please install OpenImageIO and numpy.")
raise
have_matplot = False
try:
import matplotlib.pyplot as plt
have_matplot = True
except ImportError:
logger.warning("matplotlib module not found. Image preview display is disabled.")
class OiioImageLoader(mx_render.ImageLoader):
"""
A MaterialX ImageLoader implementation that uses OpenImageIO to read image files.
Inherits from MaterialX.ImageLoader and implements the required interface methods.
Supports common image formats like PNG, JPEG, TIFF, EXR, HDR, etc.
"""
def __init__(self):
"""
Initialize the OiioImageLoader and set supported extensions."""
super().__init__()
# Set all extensions supported by OpenImageIO. e.g.
# openexr:exr,sxr,mxr;tiff:tif,tiff,tx,env,sm,vsm;jpeg:jpg,jpe,jpeg,jif,jfif,jfi;bmp:bmp,dib;cineon:cin;dds:dds;dpx:dpx;fits:fits;hdr:hdr,rgbe;ico:ico;iff:iff,z;null:null,nul;png:png;pnm:ppm,pgm,pbm,pnm,pfm;psd:psd,pdd,psb;rla:rla;sgi:sgi,rgb,rgba,bw,int,inta;softimage:pic;targa:tga,tpic;term:term;webp:webp;zfile:zfile
self._extensions = set()
oiio_extensions = oiio.get_string_attribute("extension_list")
# Split string by ";"
for group in oiio_extensions.split(";"):
# Each group is like "openexr:exr,sxr,mxr"
if ":" in group:
_, exts = group.split(":", 1)
self._extensions.update(ext.strip() for ext in exts.split(","))
else:
self._extensions.update(ext.strip() for ext in group.split(","))
logger.debug(f"Cache supported extensions: {self._extensions}")
self.preview = False
self.identifier = "OpenImageIO Custom Image Loader"
self.color_space = {}
def supportedExtensions(self):
"""
Derived method to return a set of supported image file extensions.
"""
logger.info(f"OIIO supported extensions: {self._extensions}")
return self._extensions
def set_preview(self, value):
"""
Set whether to preview images when loading and saving
@param value: Boolean indicating whether to enable preview
"""
self.preview = value
def get_identifier(self):
return "OIIO Custom Loader"
def previewImage(self, title, data, width, height, nchannels, color_space):
"""
Utility method to preview an image using matplotlib.
Handles normalization and dtype for correct display.
@param title: Title for the preview window
@param data: Image data array
@param width: Image width
@param height: Image height
@param nchannels: Number of image channels
@param color_space: Color space of the image
"""
if not self.preview:
return
if have_matplot:
# If the image is float16 (half), convert to float32
if data.dtype == np.float16:
data = data.astype(np.float32)
flat = data.reshape(height, width, nchannels)
# Always display as RGB (first 3 channels or repeat if less)
if nchannels >= 3:
rgb = flat[..., :3]
else:
rgb = np.repeat(flat[..., :1], 3, axis=-1)
# Determine if normalization is needed
if np.issubdtype(flat.dtype, np.floating):
# If float, normalize to [0, 1] for display
rgb_disp = np.clip(rgb, 0.0, 1.0)
elif np.issubdtype(flat.dtype, np.integer):
# If integer, assume 8 or 16 bit, scale if needed
if flat.dtype == np.uint8:
rgb_disp = rgb # matplotlib expects [0,255] for uint8
elif flat.dtype == np.uint16:
# Scale 16-bit to 8-bit for display
rgb_disp = (rgb / 257).astype(np.uint8)
else:
# For other integer types, try to scale to [0,255]
rgb_disp = np.clip(rgb, 0, 255).astype(np.uint8)
else:
rgb_disp = rgb
# Set title bar text for the preview window
fig, ax = plt.subplots()
ax.imshow(rgb_disp)
ax.axis("off")
#fig.patch.set_facecolor("black")
fig.canvas.manager.set_window_title(title)
info = f"Dimensions:({width}x{height}), {nchannels} channels, type={data.dtype}, colorspace={color_space}"
fig.suptitle(title, fontsize=12)
plt.title(info, fontsize=9)
plt.show()
def loadImage(self, filePath):
"""
Load an image from the file system (MaterialX interface method).
@param filePath (MaterialX.FilePath): Path to the image file
@returns MaterialX.ImagePtr: MaterialX Image object or None if loading fails
"""
file_path_str = filePath.asString()
logger.info(f"Load using OIIO loader: {file_path_str}")
if not os.path.exists(file_path_str):
print(f"Error: File '{file_path_str}' does not exist")
return None
try:
# Open the image file
img_input = oiio.ImageInput.open(file_path_str)
if not img_input:
print(f"Error: Could not open '{file_path_str}' - {oiio.geterror()}")
return None
# Get image specifications
spec = img_input.spec()
color_space = spec.getattribute("oiio:ColorSpace")
logger.info(f"ColorSpace: {color_space}")
self.color_space[file_path_str] = color_space
# Check channel count
channels = spec.nchannels
if channels > 4:
channels = 4
# Determine MaterialX base type from OIIO format
base_type = self._oiio_to_materialx_type(spec.format.basetype)
if base_type is None:
img_input.close()
print(f"Error: Unsupported image format for '{file_path_str}'")
return None
# Create MaterialX image
mx_image = mx_render.Image.create(spec.width, spec.height, channels, base_type)
mx_image.createResourceBuffer()
logger.debug(f"Create buffer with width: {spec.width}, height: {spec.height}, channels: {spec.nchannels} -> {channels}")
# Read the image data using the correct OIIO Python API (returns a bytes object)
logger.debug(f"Reading image data from '{file_path_str}' with spec: {spec}")
data = img_input.read_image(0, 0, 0, channels, spec.format)
if len(data) > 0:
logger.debug(f"Done Reading image data from '{file_path_str}' with spec: {spec}")
else:
logger.error(f"Could not read image data.")
return None
self.previewImage("Loaded MaterialX Image", data, spec.width, spec.height, channels, color_space)
# Steps:
# - Copy the OIIO data into the MaterialX image resource buffer
resource_buffer_ptr = mx_image.getResourceBuffer()
bytes_per_channel = spec.format.size()
total_bytes = spec.width * spec.height * channels * bytes_per_channel
logger.info(f"Total bytes read in: {total_bytes} (width: {spec.width}, height: {spec.height}, channels: {channels}, format: {spec.format})")
try:
ctypes.memmove(resource_buffer_ptr, (ctypes.c_char * total_bytes).from_buffer_copy(data), total_bytes)
except Exception as e:
logger.error(f"Failed to update image resource buffer: {e}")
img_input.close()
return mx_image
except Exception as e:
print(f"Error loading image from '{file_path_str}': {str(e)}")
return None
return None
def saveImage(self, filePath, image, verticalFlip=False):
"""
@brief Saves an image to disk using OpenImageIO (OIIO).
@param filePath The file path where the image will be saved. Expected to have an asString() method.
@param image The MaterialX image object to save.
@param verticalFlip Whether to vertically flip the image before saving. (Currently unused.)
@return True if the image was saved successfully, False otherwise.
"""
filename = filePath.asString()
width = image.getWidth()
height = image.getHeight()
# Clamp to RGBA
src_channels = image.getChannelCount()
channels = min(src_channels, 4)
if src_channels > 4:
logger.warning(f"Image has {src_channels} channels. Saving only first {channels} (RGBA).")
mx_basetype = image.getBaseType()
oiio_format = self._materialx_to_oiio_type(mx_basetype)
logger.info(f"mx_basetype: {mx_basetype}, oiio_format: {oiio_format}, base_stride: {image.getBaseStride()}")
if oiio_format is None:
logger.error(f"Unsupported MaterialX base type for OIIO: {mx_basetype}")
return False
buffer_addr = image.getResourceBuffer()
np_type = self._materialx_type_to_np_type(mx_basetype)
if np_type is None:
logger.error(f"No NumPy dtype mapping for base type: {mx_basetype}")
return False
try:
# Steps:
# - Maps the MaterialX base type to OIIO and NumPy types.
# - Allocates a NumPy array for the pixel data.
# - Copies the raw buffer from the image into the NumPy array.
# - Optionally previews the image for debugging.
# - Creates an OIIO ImageOutput and writes the image to disk.
#
base_stride = image.getBaseStride() # bytes per channel element
total_bytes = width * height * src_channels * base_stride
buf_type = (ctypes.c_char * total_bytes)
buf = buf_type.from_address(buffer_addr)
np_buffer = np.frombuffer(buf, dtype=np_type)
# Validate total elements before reshape to catch mismatches early
expected_elems = width * height * src_channels
if np_buffer.size != expected_elems:
logger.error(f"Buffer element count mismatch: got {np_buffer.size}, expected {expected_elems}.")
return False
np_buffer = np_buffer.reshape((height, width, src_channels))
# Keep only up to RGBA
pixels = np_buffer[..., :channels].copy()
if verticalFlip:
logger.info("Applying vertical flip before saving image.")
pixels = np.flipud(pixels)
logger.info("Previewing image after load into Image and reload for save...")
# Remove "saved_" prefix if present
search_name = filename.replace("saved_", "")
color_space = "Unknown"
for key in self.color_space:
value = self.color_space[key]
path = os.path.basename(key)
if path in search_name:
color_space = value
logger.info(f"colorspace lookup for: {search_name}. list: {color_space}")
self.previewImage("OpenImageIO Output Image", pixels, width, height, channels, color_space)
except Exception as e:
logger.error(f"Error copying buffer to pixels: {e}")
return False
out = oiio.ImageOutput.create(filename)
if not out:
logger.error("Failed to create OIIO ImageOutput.")
return False
try:
spec = oiio.ImageSpec(width, height, channels, oiio_format)
out.open(filename, spec)
out.write_image(pixels)
logger.info(f"Image saved to {filename} (w={width}, h={height}, c={channels}, type={mx_basetype})")
out.close()
return True
except Exception as e:
logger.error(f"Failed to write image: {e}")
try:
out.close()
finally:
pass
return False
def _oiio_to_materialx_type(self, oiio_basetype):
"""Convert OIIO base type to MaterialX Image base type."""
type_mapping = {
oiio.UINT8: mx_render.BaseType.UINT8,
oiio.INT8: mx_render.BaseType.INT8,
oiio.UINT16: mx_render.BaseType.UINT16,
oiio.INT16: mx_render.BaseType.INT16,
oiio.HALF: mx_render.BaseType.HALF,
oiio.FLOAT: mx_render.BaseType.FLOAT
}
return_val = type_mapping.get(oiio_basetype, None)
logger.debug(f"OIIO to MaterialX type mapping: {return_val} from {oiio_basetype}")
return return_val
def _materialx_to_oiio_type(self, mx_basetype):
"""Convert MaterialX Image base type to OIIO type."""
type_mapping = {
mx_render.BaseType.UINT8: oiio.UINT8,
mx_render.BaseType.UINT16: oiio.UINT16,
mx_render.BaseType.INT8: oiio.INT8,
mx_render.BaseType.INT16: oiio.INT16,
mx_render.BaseType.HALF: oiio.HALF,
mx_render.BaseType.FLOAT: oiio.FLOAT,
}
return_val = type_mapping.get(mx_basetype, None)
logger.debug(f"MaterialX type mapping: {mx_basetype} to {return_val}")
return return_val
def _materialx_type_to_np_type(self, mx_basetype):
"""Map MaterialX base type to NumPy dtype with explicit widths."""
type_mapping = {
mx_render.BaseType.UINT8: np.uint8,
mx_render.BaseType.UINT16: np.uint16,
mx_render.BaseType.INT8: np.int8,
mx_render.BaseType.INT16: np.int16,
mx_render.BaseType.HALF: np.float16,
mx_render.BaseType.FLOAT: np.float32,
}
return type_mapping.get(mx_basetype, None)
def test_load_save():
"""
Example usage of the OiioImageLoader class with MaterialX ImageHandler.
"""
parser = argparse.ArgumentParser(description="MaterialX OIIO Image Handler")
parser.add_argument("path", help="Path to the image file")
parser.add_argument("--flip", action="store_true", help="Flip the image vertically")
parser.add_argument("--preview", action="store_true", help="Preview the image before saving")
args = parser.parse_args()
test_image_path = args.path
if not os.path.exists(test_image_path):
logger.error(f"Image file not found: {test_image_path}")
return
# Create MaterialX handler with custom OIIO image loader
loader = OiioImageLoader()
loader.set_preview(args.preview)
handler = mx_render.ImageHandler.create(loader)
logger.info(f"Created image handler with loader ({loader.get_identifier()}): {handler is not None}")
handler.addLoader(loader)
mx_filepath = mx.FilePath(test_image_path)
# Load image using handler API
logger.info('-'*45)
logger.info(f"Loading image from path: {mx_filepath.asString()}")
mx_image = handler.acquireImage(mx_filepath)
if mx_image:
# Q: How to check for failed image load as you
# get back a 1x1 pixel image.
if mx_image.getWidth() == 1 and mx_image.getHeight() == 1:
logger.warning("Failed to load image. Got 1x1 pixel image returned")
return
logger.info(f"MaterialX Image loaded via Image Handler:")
logger.info(f" Dimensions: {mx_image.getWidth()}x{mx_image.getHeight()}")
logger.info(f" Channels: {mx_image.getChannelCount()}")
logger.info(f" Base type: {mx_image.getBaseType()}")
# Save image using handler API (to a new file)
logger.info('-'*45)
# Retrieve cached image
mx_image = handler.acquireImage(mx_filepath)
if mx_image:
out_path = mx.FilePath("saved_" + os.path.basename(test_image_path))
if handler.saveImage(out_path, mx_image, verticalFlip=args.flip):
logger.info(f"MaterialX Image saved to {out_path.asString()}")
else:
logger.error("Failed to save image.")
else:
logger.error("Failed to acquire image for saving.")
else:
logger.error("Failed to load image.")
if __name__ == "__main__":
test_load_save()

View File

@@ -0,0 +1,514 @@
#!/usr/bin/env python
"""
pybind11 documentation insertion tool.
Extracts documentation from Doxygen XML and inserts it into pybind11 bindings
using string matching via signature lookup table.
Logic:
- Builds a multi-key lookup for all functions (MaterialX::, mx::, Class::method, method)
- Handles free functions without <qualifiedname> by assuming MaterialX namespace
- Adds class context tracking to correctly document lambda-based bindings
- Supports .def(...) and .def_static(...); skips .def_readonly_static(...)
"""
import argparse
import re
import json
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Dict, Optional
# Defaults (can be overridden by CLI)
DOXYGEN_XML_DIR = Path("build/documents/doxygen_xml")
PYBIND_DIR = Path("source/PyMaterialX")
class DocExtractor:
"""Extracts documentation from Doxygen XML files and builds a lookup table."""
def __init__(self, xml_dir: Path):
self.xml_dir = xml_dir
self.class_docs: Dict[str, str] = {}
self.func_docs: Dict[str, Dict] = {}
# Multi-key lookup: all name variants point to the same doc
self.func_lookup: Dict[str, Dict] = {}
def extract(self):
if not self.xml_dir.exists():
raise FileNotFoundError(f"Doxygen XML directory not found: {self.xml_dir}")
for xml_file in self.xml_dir.glob("*.xml"):
self._process_xml_file(xml_file)
self._build_lookup_table()
print(f"Extracted {len(self.class_docs)} classes and {len(self.func_docs)} functions")
print(f"Built lookup table with {len(self.func_lookup)} keys")
def _process_xml_file(self, xml_file: Path):
tree = ET.parse(xml_file)
root = tree.getroot()
# Class / struct documentation
for compound in root.findall(".//compounddef[@kind='class']") + root.findall(".//compounddef[@kind='struct']"):
self._extract_class_doc(compound)
# Function documentation
for member in root.findall(".//memberdef[@kind='function']"):
self._extract_func_doc(member)
def _extract_class_doc(self, compound):
name = self._get_text(compound.find("compoundname"))
brief = self._get_text(compound.find("briefdescription/para"))
detail = self._extract_detail(compound.find("detaileddescription"))
doc = "\n\n".join(filter(None, [brief, detail]))
if doc:
normalized = self._normalize_name(name)
self.class_docs[normalized] = doc
def _extract_func_doc(self, member):
name = self._get_text(member.find("name"))
qualified = self._get_text(member.find("qualifiedname"))
# Many free functions have no <qualifiedname>; use the bare name
# and normalize to MaterialX::name so lookups can resolve.
if not qualified and name:
qualified = name
if not qualified:
return
brief = self._get_text(member.find("briefdescription/para"))
detail = self._extract_detail(member.find("detaileddescription"))
params = self._extract_params(member)
returns = self._get_text(member.find(".//simplesect[@kind='return']"))
normalized = self._normalize_name(qualified)
self.func_docs[normalized] = {
"brief": brief,
"detail": detail,
"params": params,
"returns": returns,
}
def _build_lookup_table(self):
for qualified_name, doc in self.func_docs.items():
for variant in self._generate_name_variants(qualified_name):
if variant not in self.func_lookup:
self.func_lookup[variant] = doc
def _generate_name_variants(self, qualified_name: str) -> list:
variants = [qualified_name]
parts = qualified_name.split("::")
# Class::method
if len(parts) >= 2:
variants.append("::".join(parts[-2:]))
# method
if len(parts) >= 1:
variants.append(parts[-1])
# mx:: variant if MaterialX::
if qualified_name.startswith("MaterialX::"):
mx_variant = qualified_name.replace("MaterialX::", "mx::", 1)
variants.append(mx_variant)
if len(parts) >= 3:
variants.append(f"mx::{parts[-2]}::{parts[-1]}")
return variants
def _normalize_name(self, name: str) -> str:
if not name:
return name
return name if name.startswith("MaterialX::") else f"MaterialX::{name}"
def _get_text(self, elem) -> str:
if elem is None:
return ""
text = "".join(elem.itertext())
return re.sub(r"\s+", " ", text).strip()
def _extract_detail(self, elem, exclude_tags={"parameterlist", "simplesect"}) -> str:
if elem is None:
return ""
parts = []
for para in elem.findall("para"):
if not any(para.find(tag) is not None for tag in exclude_tags):
t = self._get_text(para)
if t:
parts.append(t)
return "\n\n".join(parts)
def _extract_params(self, member) -> Dict[str, str]:
params = {}
for param_item in member.findall(".//parameterlist[@kind='param']/parameteritem"):
name = self._get_text(param_item.find("parameternamelist/parametername"))
desc = self._get_text(param_item.find("parameterdescription"))
if name:
params[name] = desc
return params
class DocInserter:
"""Inserts documentation into pybind11 binding files."""
def __init__(self, extractor: DocExtractor, pybind_dir: Path, force_replace: bool = False):
self.extractor = extractor
self.pybind_dir = pybind_dir
self.force_replace = force_replace
self.class_pattern = re.compile(r"py::class_<")
self.def_pattern = re.compile(r"\.def(?:_static)?\s*\(")
# Match .def and .def_static; skip .def_readonly_static (constants)
self.def_pattern = re.compile(r"\.def(?:_static)?\s*\(")
self.skip_pattern = re.compile(r"\.def_readonly_static\s*\(")
def process_all_files(self):
cpp_files = list(self.pybind_dir.rglob("*.cpp"))
patched = 0
for cpp_file in cpp_files:
if self._process_file(cpp_file):
patched += 1
print(f"\nProcessed {len(cpp_files)} files, patched {patched}")
def _process_file(self, cpp_file: Path) -> bool:
content = cpp_file.read_text(encoding="utf-8")
original = content
content = self._insert_class_docs(content)
content = self._insert_method_docs(content)
if content != original:
cpp_file.write_text(content, encoding="utf-8")
print(f" - {cpp_file.relative_to(self.pybind_dir.parent)}")
return True
else:
print(f" - {cpp_file.relative_to(self.pybind_dir.parent)}")
return False
def _insert_class_docs(self, content: str) -> str:
result = []
pos = 0
for match in self.class_pattern.finditer(content):
result.append(content[pos:match.start()])
start = match.start()
template_end = self._find_template_end(content, start)
if template_end == -1:
result.append(content[start:match.end()])
pos = match.end()
continue
paren_start = content.find('(', template_end)
if paren_start == -1:
result.append(content[start:match.end()])
pos = match.end()
continue
paren_end = self._find_matching_paren(content, paren_start)
if paren_end == -1:
result.append(content[start:match.end()])
pos = match.end()
continue
args_text = content[paren_start + 1:paren_end]
class_name = self._extract_class_name(args_text)
if class_name:
doc = self.extractor.class_docs.get(self.extractor._normalize_name(class_name))
if doc:
args = self._split_args(args_text)
if len(args) >= 3 and not self.force_replace:
result.append(content[start:paren_end + 1])
pos = paren_end + 1
continue
escaped = self._escape_for_cpp(doc)
if len(args) >= 3 and self.force_replace:
new_args = args[:2] + [f'"{escaped}"'] + args[3:]
result.append(content[start:paren_start + 1])
result.append(", ".join(new_args))
result.append(")")
else:
result.append(content[start:paren_end])
result.append(f', "{escaped}")')
pos = paren_end + 1
continue
result.append(content[start:paren_end + 1])
pos = paren_end + 1
result.append(content[pos:])
return "".join(result)
def _insert_method_docs(self, content: str) -> str:
# Build a map of line numbers to class contexts
class_contexts = self._extract_class_contexts(content)
result = []
pos = 0
for match in self.def_pattern.finditer(content):
if self.skip_pattern.match(content, match.start()):
continue
result.append(content[pos:match.start()])
start = match.start()
paren_start = content.find('(', start)
if paren_start == -1:
result.append(content[start:match.end()])
pos = match.end()
continue
paren_end = self._find_matching_paren(content, paren_start)
if paren_end == -1:
result.append(content[start:match.end()])
pos = match.end()
continue
args_text = content[paren_start + 1:paren_end]
args = self._split_args(args_text)
if len(args) < 2:
result.append(content[start:paren_end + 1])
pos = paren_end + 1
continue
has_doc = self._has_docstring(args)
if has_doc and not self.force_replace:
result.append(content[start:paren_end + 1])
pos = paren_end + 1
continue
callable_ref = args[1].strip()
current_line = content[:start].count('\n')
class_context = class_contexts.get(current_line)
doc_entry = self._find_doc_for_callable(callable_ref, class_context)
if doc_entry:
docstring = self._build_docstring(doc_entry)
escaped = self._escape_for_cpp(docstring)
if has_doc and self.force_replace:
doc_idx = self._find_docstring_arg_index(args)
if doc_idx is not None:
new_args = args[:doc_idx] + [f'"{escaped}"'] + args[doc_idx + 1:]
result.append(content[start:paren_start + 1])
result.append(", ".join(new_args))
result.append(")")
pos = paren_end + 1
continue
result.append(content[start:paren_end])
result.append(f', "{escaped}")')
pos = paren_end + 1
continue
result.append(content[start:paren_end + 1])
pos = paren_end + 1
result.append(content[pos:])
return "".join(result)
def _extract_class_contexts(self, content: str) -> Dict[int, str]:
contexts = {}
for match in self.class_pattern.finditer(content):
start = match.start()
template_end = self._find_template_end(content, start)
if template_end == -1:
continue
template_start = content.find('<', start) + 1
template_content = content[template_start:template_end - 1]
class_type = template_content.split(',')[0].strip()
class_name = class_type.split('::')[-1] if '::' in class_type else class_type
start_line = content[:start].count('\n')
end_pos = content.find(';', start)
if end_pos != -1:
end_line = content[:end_pos].count('\n')
for line in range(start_line, end_line + 1):
contexts[line] = class_name
return contexts
def _find_doc_for_callable(self, callable_ref: str, class_context: Optional[str] = None) -> Optional[Dict]:
callable_ref = callable_ref.strip()
# Function pointers like &mx::Class::method or &MaterialX::name
if callable_ref.startswith('&'):
name = callable_ref[1:].strip()
name = re.sub(r'[,\s]+$', '', name)
return self.extractor.func_lookup.get(name)
# Lambdas: look for elem.method( or obj->method(
method_match = re.search(r'[\.\->](\w+)\s*\(', callable_ref)
if method_match:
method_name = method_match.group(1)
if class_context:
for prefix in ("", "mx::", "MaterialX::"):
qualified = f"{prefix}{class_context}::{method_name}" if prefix else f"{class_context}::{method_name}"
doc = self.extractor.func_lookup.get(qualified)
if doc:
return doc
return self.extractor.func_lookup.get(method_name)
return None
def _build_docstring(self, doc_entry: Dict) -> str:
parts = []
if doc_entry.get("brief"):
parts.append(doc_entry["brief"])
if doc_entry.get("detail"):
parts.append(doc_entry["detail"])
params = doc_entry.get("params", {})
if params:
param_lines = ["Args:"]
for name, desc in params.items():
param_lines.append(f" {name}: {desc}" if desc else f" {name}:")
parts.append("\n".join(param_lines))
if doc_entry.get("returns"):
parts.append(f"Returns:\n {doc_entry['returns']}")
return "\n\n".join(parts)
def _escape_for_cpp(self, s: str) -> str:
if not s:
return ""
s = s.replace("\\", "\\\\").replace('"', '\\"')
s = s.replace("\n", "\\n")
return s
def _find_template_end(self, content: str, start: int) -> int:
pos = content.find('<', start)
if pos == -1:
return -1
depth = 1
i = pos + 1
in_string = False
while i < len(content) and depth > 0:
c = content[i]
if c == '"' and content[i - 1] != '\\':
in_string = not in_string
elif not in_string:
if c == '<':
depth += 1
elif c == '>':
depth -= 1
i += 1
return i if depth == 0 else -1
def _find_matching_paren(self, content: str, start: int) -> int:
depth = 0
in_string = False
escape = False
for i in range(start, len(content)):
c = content[i]
if escape:
escape = False
continue
if c == '\\':
escape = True
continue
if c == '"':
in_string = not in_string
continue
if not in_string:
if c == '(':
depth += 1
elif c == ')':
depth -= 1
if depth == 0:
return i
return -1
def _split_args(self, args_text: str) -> list:
args = []
current = []
depth = 0
in_string = False
escape = False
for c in args_text:
if escape:
current.append(c)
escape = False
continue
if c == '\\':
current.append(c)
escape = True
continue
if c == '"':
in_string = not in_string
current.append(c)
continue
if not in_string:
if c in '(<':
depth += 1
elif c in ')>':
depth -= 1
elif c == ',' and depth == 0:
args.append("".join(current).strip())
current = []
continue
current.append(c)
if current:
args.append("".join(current).strip())
return args
def _extract_class_name(self, args_text: str) -> Optional[str]:
args = self._split_args(args_text)
if len(args) >= 2:
return args[1].strip().strip('"')
return None
def _has_docstring(self, args: list) -> bool:
for arg in args[2:]:
a = arg.strip()
if not a.startswith("py::arg") and a.startswith('"'):
return True
return False
def _find_docstring_arg_index(self, args: list) -> Optional[int]:
for i, arg in enumerate(args[2:], start=2):
a = arg.strip()
if not a.startswith("py::arg") and a.startswith('"'):
return i
return None
def main():
parser = argparse.ArgumentParser(description="Extract Doxygen docs and insert into pybind11 bindings (simplified)")
parser.add_argument("-d", "--doxygen_xml_dir", type=Path, default=Path("build/documents/doxygen_xml"), help="Path to Doxygen XML output directory")
parser.add_argument("-p", "--pybind_dir", type=Path, default=Path("source/PyMaterialX"), help="Path to pybind11 bindings directory")
parser.add_argument("-f", "--force", action="store_true", help="Force replace existing docstrings")
parser.add_argument("-j", "--write_json", action="store_true", help="Write extracted docs to JSON files")
args = parser.parse_args()
if not args.doxygen_xml_dir.exists():
print(f"Error: Doxygen XML directory not found: {args.doxygen_xml_dir}")
return 1
if not args.pybind_dir.exists():
print(f"Error: Pybind directory not found: {args.pybind_dir}")
return 1
print("Extracting documentation from Doxygen XML...")
extractor = DocExtractor(args.doxygen_xml_dir)
extractor.extract()
if args.write_json:
print("\nWriting JSON files...")
Path("class_docs.json").write_text(json.dumps(extractor.class_docs, indent=2), encoding="utf-8")
Path("func_docs.json").write_text(json.dumps(extractor.func_docs, indent=2), encoding="utf-8")
print(" - class_docs.json")
print(" - func_docs.json")
print(f"\n{'Replacing' if args.force else 'Inserting'} documentation in pybind11 files...")
inserter = DocInserter(extractor, args.pybind_dir, args.force)
inserter.process_all_files()
print("\nDone!")
return 0
if __name__ == "__main__":
exit(main())

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python
'''
Generate a baked translated version of each material in the input document, using the ShaderTranslator class in the MaterialXShaderGen library
and the TextureBaker class in the MaterialXRenderGlsl library.
'''
import sys, os, argparse
from sys import platform
import MaterialX as mx
from MaterialX import PyMaterialXGenShader as mx_gen_shader
from MaterialX import PyMaterialXRender as mx_render
from MaterialX import PyMaterialXRenderGlsl as mx_render_glsl
if platform == "darwin":
from MaterialX import PyMaterialXRenderMsl as mx_render_msl
def main():
parser = argparse.ArgumentParser(description="Generate a translated baked version of each material in the input document.")
parser.add_argument("--width", dest="width", type=int, default=0, help="Specify an optional width for baked textures (defaults to the maximum image height in the source document).")
parser.add_argument("--height", dest="height", type=int, default=0, help="Specify an optional height for baked textures (defaults to the maximum image width in the source document).")
parser.add_argument("--hdr", dest="hdr", action="store_true", help="Bake images with high dynamic range (e.g. in HDR or EXR format).")
parser.add_argument("--path", dest="paths", action='append', nargs='+', help="An additional absolute search path location (e.g. '/projects/MaterialX')")
parser.add_argument("--library", dest="libraries", action='append', nargs='+', help="An additional relative path to a custom data library folder (e.g. 'libraries/custom')")
parser.add_argument('--writeDocumentPerMaterial', dest='writeDocumentPerMaterial', type=mx.stringToBoolean, default=True, help='Specify whether to write baked materials to separate MaterialX documents. Default is True')
if platform == "darwin":
parser.add_argument("--glsl", dest="useGlslBackend", default=False, type=bool, help="Set to True to use GLSL backend (default = Metal).")
parser.add_argument(dest="inputFilename", help="Filename of the input document.")
parser.add_argument(dest="outputFilename", help="Filename of the output document.")
parser.add_argument(dest="destShader", help="Destination shader for translation")
opts = parser.parse_args()
# Load standard and custom data libraries.
stdlib = mx.createDocument()
searchPath = mx.getDefaultDataSearchPath()
searchPath.append(os.path.dirname(opts.inputFilename))
libraryFolders = []
if opts.paths:
for pathList in opts.paths:
for path in pathList:
searchPath.append(path)
if opts.libraries:
for libraryList in opts.libraries:
for library in libraryList:
libraryFolders.append(library)
libraryFolders.extend(mx.getDefaultDataLibraryFolders())
mx.loadLibraries(libraryFolders, searchPath, stdlib)
# Read and validate the source document.
doc = mx.createDocument()
try:
mx.readFromXmlFile(doc, opts.inputFilename)
doc.setDataLibrary(stdlib)
except mx.ExceptionFileMissing as err:
print(err)
sys.exit(0)
valid, msg = doc.validate()
if not valid:
print("Validation warnings for input document:")
print(msg)
# Check the document for a UDIM set.
udimSetValue = doc.getGeomPropValue(mx.UDIM_SET_PROPERTY)
udimSet = udimSetValue.getData() if udimSetValue else []
# Compute baking resolution from the source document.
imageHandler = mx_render.ImageHandler.create(mx_render.StbImageLoader.create())
imageHandler.setSearchPath(searchPath)
if udimSet:
resolver = doc.createStringResolver()
resolver.setUdimString(udimSet[0])
imageHandler.setFilenameResolver(resolver)
imageVec = imageHandler.getReferencedImages(doc)
bakeWidth, bakeHeight = mx_render.getMaxDimensions(imageVec)
# Apply baking resolution settings.
if opts.width > 0:
bakeWidth = opts.width
if opts.height > 0:
bakeHeight = opts.height
bakeWidth = max(bakeWidth, 4)
bakeHeight = max(bakeHeight, 4)
# Translate materials between shading models
translator = mx_gen_shader.ShaderTranslator.create()
try:
translator.translateAllMaterials(doc, opts.destShader)
except mx.Exception as err:
print(err)
sys.exit(0)
# Bake translated materials to flat textures.
baseType = mx_render.BaseType.FLOAT if opts.hdr else mx_render.BaseType.UINT8
if platform == "darwin" and not opts.useGlslBackend:
baker = mx_render_msl.TextureBaker.create(bakeWidth, bakeHeight, baseType)
else:
baker = mx_render_glsl.TextureBaker.create(bakeWidth, bakeHeight, baseType)
baker.writeDocumentPerMaterial(opts.writeDocumentPerMaterial)
baker.bakeAllMaterials(doc, searchPath, opts.outputFilename)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python
'''
Generate the "NodeGraphs.mtlx" example file programmatically.
'''
import MaterialX as mx
def main():
doc = mx.createDocument()
#
# Nodegraph example 1
#
ng1 = doc.addNodeGraph("NG_example1")
img1 = ng1.addNode("image", "img1", "color3")
# Because filenames look like string types, it is necessary to explicitly declare
# this parameter value as type "filename".
img1.setInputValue("file", "layer1.tif", "filename")
img2 = ng1.addNode("image", "img2", "color3")
img2.setInputValue("file", "layer2.tif", "filename")
img3 = ng1.addNode("image", "img3", "float")
img3.setInputValue("file", "mask1.tif", "filename")
n0 = ng1.addNode("mix", "n0", "color3")
# To connect an input to another node, you must first add the input with the expected
# type, and then setConnectedNode() that input to the desired Node object.
infg = n0.addInput("fg", "color3")
infg.setConnectedNode(img1)
inbg = n0.addInput("bg", "color3")
inbg.setConnectedNode(img2)
inmx = n0.addInput("mix", "float")
inmx.setConnectedNode(img3)
n1 = ng1.addNode("multiply", "n1", "color3")
inp1 = n1.addInput("in1", "color3")
inp1.setConnectedNode(n0)
inp2 = n1.setInputValue("in2", 0.22)
nout = ng1.addOutput("diffuse", "color3")
nout.setConnectedNode(n1)
#
# Nodegraph example 3
#
ng3 = doc.addNodeGraph("NG_example3")
img1 = ng3.addNode("image", "img1", "color3")
img1.setInputValue("file", "<diff_albedo>", "filename")
img2 = ng3.addNode("image", "img2", "color3")
img2.setInputValue("file", "<dirt_albedo>", "filename")
img3 = ng3.addNode("image", "img3", "float")
img3.setInputValue("file", "<areamask>", "filename")
img4 = ng3.addNode("image", "img4", "float")
img4.setInputValue("file", "<noisemask>", "filename")
n5 = ng3.addNode("constant", "n5", "color3")
# For colorN, vectorN or matrix types, use the appropriate mx Type constructor.
n5.setInputValue("value", mx.Color3(0.8,1.0,1.3))
n6 = ng3.addNode("multiply", "n6", "color3")
inp1 = n6.addInput("in1", "color3")
inp1.setConnectedNode(n5)
inp2 = n6.addInput("in2", "color3")
inp2.setConnectedNode(img1)
n7 = ng3.addNode("contrast", "n7", "color3")
inp = n7.addInput("in", "color3")
inp.setConnectedNode(img2)
n7.setInputValue("amount", 0.2)
n7.setInputValue("pivot", 0.5)
n8 = ng3.addNode("mix", "n8", "color3")
infg = n8.addInput("fg", "color3")
infg.setConnectedNode(n7)
inbg = n8.addInput("bg", "color3")
inbg.setConnectedNode(n6)
inmx = n8.addInput("mix", "float")
inmx.setConnectedNode(img3)
t1 = ng3.addNode("texcoord", "t1", "vector2")
m1 = ng3.addNode("multiply", "m1", "vector2")
inp1 = m1.addInput("in1", "vector2")
inp1.setConnectedNode(t1)
m1.setInputValue("in2", 0.003)
# If limited floating-point precision results in output value strings like "0.00299999",
# you could instead write this as a ValueString (must add the input to the node first):
# inp2 = m1.addInput("in2", "float")
# inp2.setValueString("0.003")
n9 = ng3.addNode("noise2d", "n9", "color3")
intx = n9.addInput("texcoord", "vector2")
intx.setConnectedNode(m1)
n9.setInputValue("amplitude", mx.Vector3(0.05,0.04,0.06))
n10 = ng3.addNode("inside", "n10", "color3")
inmask = n10.addInput("mask", "float")
inmask.setConnectedNode(img4)
inp = n10.addInput("in", "color3")
inp.setConnectedNode(n9)
n11 = ng3.addNode("add", "n11", "color3")
inp1 = n11.addInput("in1", "color3")
inp1.setConnectedNode(n10)
inp2 = n11.addInput("in2", "color3")
inp2.setConnectedNode(n8)
nout1 = ng3.addOutput("albedo", "color3")
nout1.setConnectedNode(n11)
nout2 = ng3.addOutput("areamask", "float")
nout2.setConnectedNode(img3)
# It is not necessary to validate a document before writing but it's nice
# to know for sure. And you can validate any element (and its children)
# independently, not just the whole document.
rc = ng1.validate()
if (len(rc) >= 1 and rc[0]):
print("Nodegraph %s is valid." % ng1.getName())
else:
print("Nodegraph %s is NOT valid: %s" % (ng1.getName(), str(rc[1])))
rc = ng3.validate()
if (len(rc) >= 1 and rc[0]):
print("Nodegraph %s is valid." % ng3.getName())
else:
print("Nodegraph %s is NOT valid: %s" % (ng3.getName(), str(rc[1])))
outfile = "myNodeGraphs.mtlx"
mx.writeToXmlFile(doc, outfile)
print("Wrote nodegraphs to %s" % outfile)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,18 @@
from setuptools import setup
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
def getRecursivePackageData(root):
packageData = []
for dirpath, dirnames, filenames in os.walk(root):
relpath = os.path.relpath(dirpath, root)
packageData.append(os.path.join(relpath, '*.*'))
return packageData
setup(name='MaterialX',
url='www.materialx.org',
version='${MATERIALX_MAJOR_VERSION}.${MATERIALX_MINOR_VERSION}.${MATERIALX_BUILD_VERSION}',
packages=['MaterialX'],
package_data={'MaterialX' : getRecursivePackageData('MaterialX')},
zip_safe = False)