Files
SDL3CPlusPlus/scripts/validate_shaders.py
2026-01-08 03:20:44 +00:00

326 lines
13 KiB
Python
Executable File
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Mega-strict shader pipeline validator
Catches issues before they reach the GPU driver
"""
import re
import sys
from dataclasses import dataclass
from typing import List, Dict, Set, Optional, Tuple
@dataclass()
class ShaderAttribute:
    """A single ``layout(location = N)`` in/out variable parsed from GLSL."""

    location: int  # value given in the layout location qualifier
    type: str      # GLSL type name as written, e.g. 'vec3'
    name: str      # identifier of the variable in the shader
@dataclass()
class VertexLayoutAttribute:
    """One attribute of the CPU-side vertex buffer layout."""

    location: int  # shader input location this attribute feeds
    type: str      # layout type name, e.g. 'float3'
    name: str      # human-readable attribute name, e.g. 'Position'
    size: int      # size of the attribute in bytes
@dataclass()
class ValidationError:
    """A single diagnostic produced while validating the shader pipeline."""

    # Severity label: one of 'ERROR', 'WARNING', 'INFO'.
    severity: str
    # Human-readable description of the finding.
    message: str
    # Optional source position of the finding; both default to None.
    shader_name: Optional[str] = None
    line_number: Optional[int] = None
class ShaderValidator:
    """Collect diagnostics about a GLSL shader pipeline and its vertex layout.

    The ``extract_*`` and ``validate_*`` methods do not raise on bad shaders;
    they append ``ValidationError`` records with severity ``'ERROR'``,
    ``'WARNING'`` or ``'INFO'`` to ``self.errors``. Call ``print_report`` at
    the end to render them and obtain the overall pass/fail result.

    GLSL is inspected with per-line regular expressions rather than a real
    parser, so a declaration is only recognised when it fits on one line.
    Comments are stripped before scanning, so commented-out code is ignored.
    """

    # Interpolation / auxiliary / invariance / precision qualifiers that may
    # surround the ``in``/``out`` storage qualifier, e.g.
    # ``layout(location = 2) flat in int v_id;`` or ``in highp vec3 v;``.
    _OPTIONAL_QUALIFIERS = (
        r'(?:(?:flat|smooth|noperspective|centroid|sample|patch|'
        r'invariant|precise|highp|mediump|lowp)\s+)*'
    )
    # ``layout(...)`` containing ``location = N`` (group 1 = N); other layout
    # parameters inside the same parentheses, such as ``component``, are allowed.
    _LOCATION_LAYOUT = r'layout\s*\([^)]*\blocation\s*=\s*(\d+)[^)]*\)\s*'
    # Located declarations: group 1 = location, group 2 = type, group 3 = name.
    _INPUT_DECL_RE = re.compile(
        _LOCATION_LAYOUT + _OPTIONAL_QUALIFIERS + r'\bin\s+'
        + _OPTIONAL_QUALIFIERS + r'(\w+)\s+(\w+)\s*;'
    )
    _OUTPUT_DECL_RE = re.compile(
        _LOCATION_LAYOUT + _OPTIONAL_QUALIFIERS + r'\bout\s+'
        + _OPTIONAL_QUALIFIERS + r'(\w+)\s+(\w+)\s*;'
    )
    # Any single-line interface declaration (group 1 = in/out, 2 = type,
    # 3 = name), and a layout qualifier that assigns it a location.
    _IO_DECL_RE = re.compile(
        r'\b(in|out)\s+' + _OPTIONAL_QUALIFIERS + r'(\w+)\s+(\w+)\s*;'
    )
    _HAS_LOCATION_RE = re.compile(r'\blayout\s*\([^)]*\blocation\s*=')
    # ``//`` line comments and ``/* */`` block comments. The source is scanned
    # position by position, so whichever opener appears first wins.
    _COMMENT_RE = re.compile(r'//[^\n]*|/\*.*?\*/', re.DOTALL)
    # GLSL attribute type -> vertex-layout type name accepted for it
    # (an identical name on both sides is also accepted).
    _TYPE_MAP = {
        'float': 'float',
        'vec2': 'float2',
        'vec3': 'float3',
        'vec4': 'float4',
        'int': 'int',
        'ivec2': 'int2',
        'ivec3': 'int3',
        'ivec4': 'int4',
        'uint': 'uint',
        'uvec2': 'uint2',
        'uvec3': 'uint3',
        'uvec4': 'uint4',
    }

    def __init__(self):
        # Every diagnostic of every severity, in the order it was recorded.
        self.errors: List[ValidationError] = []

    def error(self, msg: str, shader: Optional[str] = None, line: Optional[int] = None):
        """Record an ERROR; any error makes ``print_report`` return False."""
        self.errors.append(ValidationError('ERROR', msg, shader, line))

    def warning(self, msg: str, shader: Optional[str] = None, line: Optional[int] = None):
        """Record a WARNING (reported, but does not fail validation)."""
        self.errors.append(ValidationError('WARNING', msg, shader, line))

    def info(self, msg: str, shader: Optional[str] = None, line: Optional[int] = None):
        """Record an INFO entry (only printed in verbose reports)."""
        self.errors.append(ValidationError('INFO', msg, shader, line))

    @classmethod
    def _strip_comments(cls, glsl_source: str) -> List[str]:
        """Return *glsl_source* split into lines with every comment blanked.

        A comment is replaced by the newlines it contained (or a single space
        when it had none), so indices in the result still correspond to the
        original line numbers.
        """
        def _blank(match):
            newlines = match.group(0).count('\n')
            return '\n' * newlines if newlines else ' '

        return cls._COMMENT_RE.sub(_blank, glsl_source).split('\n')

    def _extract_located(self, glsl_source: str, pattern) -> List[Tuple[int, ShaderAttribute]]:
        """Return ``(line_number, attribute)`` for every line matching *pattern*.

        *pattern* must be one of the compiled ``_*_DECL_RE`` expressions whose
        three groups are location, type and name. Line numbers are 1-based.
        """
        found = []
        for line_num, line in enumerate(self._strip_comments(glsl_source), 1):
            match = pattern.search(line)
            if match:
                location, attr_type, name = match.groups()
                found.append((line_num, ShaderAttribute(int(location), attr_type, name)))
        return found

    def _index_by_location(self, attributes, what: str, shader_name: Optional[str] = None) -> dict:
        """Map location -> attribute, reporting duplicate locations as errors.

        A plain dict comprehension would silently keep only one of two
        attributes bound to the same location; here the first one is kept and
        the conflict is recorded instead.
        """
        by_location = {}
        for attr in attributes:
            if attr.location in by_location:
                self.error(f"Duplicate location {attr.location} in {what}: "
                           f"'{by_location[attr.location].name}' and '{attr.name}'",
                           shader_name)
            else:
                by_location[attr.location] = attr
        return by_location

    def extract_shader_inputs(self, glsl_source: str, shader_name: str) -> List[ShaderAttribute]:
        """Extract ``layout(location = N) in <type> <name>;`` declarations.

        Each attribute found is also logged as an INFO entry with its line.
        Returns the attributes in source order.
        """
        inputs = []
        for line_num, attr in self._extract_located(glsl_source, self._INPUT_DECL_RE):
            inputs.append(attr)
            self.info(f"Found input: location={attr.location}, type={attr.type}, name={attr.name}",
                      shader_name, line_num)
        return inputs

    def extract_shader_outputs(self, glsl_source: str, shader_name: str) -> List[ShaderAttribute]:
        """Extract ``layout(location = N) out <type> <name>;`` declarations.

        ``shader_name`` is accepted for symmetry with
        ``extract_shader_inputs``; outputs are not logged.
        """
        return [attr for _, attr in self._extract_located(glsl_source, self._OUTPUT_DECL_RE)]

    def validate_vertex_layout_match(self, shader_inputs: List[ShaderAttribute],
                                     vertex_layout: List[VertexLayoutAttribute],
                                     shader_name: str):
        """Validate that vertex shader inputs are fed by the vertex layout.

        Errors: duplicate locations on either side, shader inputs with no
        layout entry, and incompatible types (see ``types_compatible``).
        Warning: a layout attribute at a location below the shader's highest
        input location that the shader does not read.
        """
        shader_locs = self._index_by_location(
            shader_inputs, f"inputs of shader '{shader_name}'", shader_name)
        layout_locs = self._index_by_location(vertex_layout, "vertex layout", shader_name)

        for loc, shader_attr in shader_locs.items():
            if loc not in layout_locs:
                self.error(f"Shader '{shader_name}' expects input at location {loc} "
                           f"({shader_attr.name}: {shader_attr.type}) but vertex layout doesn't provide it",
                           shader_name)

        for loc in sorted(shader_locs.keys() & layout_locs.keys()):
            shader_type = shader_locs[loc].type
            layout_type = layout_locs[loc].type
            if not self.types_compatible(shader_type, layout_type):
                self.error(f"Type mismatch at location {loc}: shader expects {shader_type} "
                           f"but vertex layout provides {layout_type}", shader_name)

        # Location gaps: only locations below the highest one the shader reads
        # are checked; layout attributes beyond it are not reported.
        if shader_locs:
            for i in range(max(shader_locs)):
                if i not in shader_locs and i in layout_locs:
                    self.warning(f"Vertex layout provides unused attribute at location {i}",
                                 shader_name)

    def types_compatible(self, shader_type: str, layout_type: str) -> bool:
        """Return True when *layout_type* can feed a GLSL input of *shader_type*.

        Accepts the mapped layout name (``vec3`` -> ``float3``) or the same
        spelling on both sides.
        """
        return layout_type in (shader_type, self._TYPE_MAP.get(shader_type))

    def validate_vertex_buffer_stride(self, vertex_layout: List[VertexLayoutAttribute],
                                      actual_vertex_size: int):
        """Validate that the summed attribute sizes equal the vertex struct size.

        The expected stride is the plain sum of ``attr.size`` (no padding).
        On mismatch an error is recorded plus a per-attribute INFO breakdown.
        """
        expected_stride = sum(attr.size for attr in vertex_layout)
        if expected_stride != actual_vertex_size:
            self.error(f"Vertex layout stride mismatch: layout expects {expected_stride} bytes "
                       f"but actual vertex struct is {actual_vertex_size} bytes")
            self.info("Layout breakdown:")
            for attr in vertex_layout:
                self.info(f" {attr.name}: {attr.size} bytes")

    def validate_location_continuity(self, shader_inputs: List[ShaderAttribute], shader_name: str):
        """Warn unless input locations are exactly 0..n-1 (no gaps, no repeats)."""
        if not shader_inputs:
            return
        locations = sorted(attr.location for attr in shader_inputs)
        expected = list(range(len(locations)))
        if locations != expected:
            self.warning(f"Shader '{shader_name}' has non-continuous location assignments: {locations} "
                         f"(expected: {expected}). This may cause issues with some drivers.",
                         shader_name)

    def validate_spir_v_compatibility(self, glsl_source: str, shader_name: str):
        """Flag ``in``/``out`` declarations lacking ``layout(location = N)``.

        GLSL compiled for Vulkan/SPIR-V requires an explicit location on every
        user-defined interface variable. Comments are stripped first, so
        commented-out declarations are ignored while a declaration followed
        by a trailing ``// comment`` is still checked. ``gl_*`` built-ins are
        skipped.
        """
        for line_num, line in enumerate(self._strip_comments(glsl_source), 1):
            decl = self._IO_DECL_RE.search(line)
            if decl is None or self._HAS_LOCATION_RE.search(line):
                continue
            _, decl_type, decl_name = decl.groups()
            if decl_type.startswith('gl_') or decl_name.startswith('gl_'):
                continue
            self.error("SPIR-V requires 'layout(location=N)' for all inputs/outputs",
                       shader_name, line_num)

    def validate_interface_matching(self, vertex_outputs: List[ShaderAttribute],
                                    fragment_inputs: List[ShaderAttribute],
                                    vs_name: str, fs_name: str):
        """Validate that every fragment input is produced by the vertex stage.

        Each FS input location must exist among the VS outputs with exactly
        the same GLSL type. Extra VS outputs are allowed.
        """
        vs_locs = self._index_by_location(
            vertex_outputs, f"outputs of vertex shader '{vs_name}'", vs_name)
        fs_locs = self._index_by_location(
            fragment_inputs, f"inputs of fragment shader '{fs_name}'", fs_name)
        for loc, fs_attr in fs_locs.items():
            vs_attr = vs_locs.get(loc)
            if vs_attr is None:
                self.error(f"Fragment shader '{fs_name}' expects input at location {loc} "
                           f"({fs_attr.name}) but vertex shader '{vs_name}' doesn't output it",
                           fs_name)
            elif vs_attr.type != fs_attr.type:
                self.error(f"Type mismatch at location {loc}: VS outputs {vs_attr.type} "
                           f"but FS expects {fs_attr.type}", fs_name)

    def validate_attribute_name_conventions(self, shader_inputs: List[ShaderAttribute],
                                            shader_name: str):
        """Warn about input names that lack one of the accepted prefixes."""
        valid_prefixes = ['i_', 'a_', 'v_']
        prefix_tuple = tuple(valid_prefixes)  # str.startswith accepts a tuple
        for attr in shader_inputs:
            if not attr.name.startswith(prefix_tuple):
                self.warning(f"Attribute '{attr.name}' doesn't follow naming convention "
                             f"(expected prefix: {valid_prefixes})", shader_name)

    @staticmethod
    def _format_location(entry: ValidationError) -> str:
        """Render `` [shader:line]`` (or `` [shader]``) for a report line."""
        if not entry.shader_name:
            return ""
        if entry.line_number is None:
            return f" [{entry.shader_name}]"
        return f" [{entry.shader_name}:{entry.line_number}]"

    def print_report(self, verbose: Optional[bool] = None) -> bool:
        """Print all collected diagnostics grouped by severity.

        Args:
            verbose: include INFO entries. ``None`` (the default) enables them
                when ``--verbose`` is present in ``sys.argv``.

        Returns:
            True when no ERROR entries were recorded.
        """
        if verbose is None:
            verbose = '--verbose' in sys.argv
        errors = [e for e in self.errors if e.severity == 'ERROR']
        warnings = [e for e in self.errors if e.severity == 'WARNING']
        infos = [e for e in self.errors if e.severity == 'INFO']

        print("\n" + "=" * 80)
        print("SHADER PIPELINE VALIDATION REPORT")
        print("=" * 80)
        if errors:
            print(f"\n🔴 ERRORS ({len(errors)}):")
            for entry in errors:
                print(f"{entry.message}{self._format_location(entry)}")
        if warnings:
            print(f"\n⚠️ WARNINGS ({len(warnings)}):")
            for entry in warnings:
                print(f"{entry.message}{self._format_location(entry)}")
        if infos and verbose:
            print(f"\n INFO ({len(infos)}):")
            for entry in infos:
                print(f" {entry.message}{self._format_location(entry)}")
        print("\n" + "=" * 80)
        print(f"Summary: {len(errors)} errors, {len(warnings)} warnings, {len(infos)} info")
        print("=" * 80 + "\n")
        return len(errors) == 0
def main():
    """Validate the built-in example shader pair against the expected vertex layout.

    Returns a process exit code: 0 when the report contains no errors,
    1 otherwise.
    """
    validator = ShaderValidator()

    # Intended to match the layout in bgfx_graphics_backend.cpp:
    # Position, Normal, Tangent, TexCoord0, Color0 as (type, name, bytes),
    # assigned to consecutive locations starting at 0.
    layout_spec = [
        ('float3', 'Position', 12),
        ('float3', 'Normal', 12),
        ('float3', 'Tangent', 12),
        ('float2', 'TexCoord0', 8),
        ('float3', 'Color0', 12),
    ]
    vertex_layout = [
        VertexLayoutAttribute(location, attr_type, attr_name, attr_size)
        for location, (attr_type, attr_name, attr_size) in enumerate(layout_spec)
    ]
    # Vertex struct size: 3+3+3+2+3 = 14 floats = 56 bytes.
    vertex_struct_size = 56

    vs_name = "example_vs"
    fs_name = "example_fs"
    # Example sources; a real run would read these from the shader files.
    example_vs = """
#version 450
layout (location = 0) in vec3 i_position;
layout (location = 1) in vec3 i_normal;
layout (location = 2) in vec3 i_tangent;
layout (location = 3) in vec2 i_texcoord_0;
layout (location = 0) out vec3 v_normal;
layout (location = 1) out vec2 v_texcoord;
void main() {
gl_Position = vec4(i_position, 1.0);
v_normal = i_normal;
v_texcoord = i_texcoord_0;
}
"""
    example_fs = """
#version 450
layout (location = 0) in vec3 v_normal;
layout (location = 1) in vec2 v_texcoord;
layout (location = 0) out vec4 fragColor;
void main() {
fragColor = vec4(v_normal * 0.5 + 0.5, 1.0);
}
"""

    print("Running shader pipeline validation...")

    # Parse the interface declarations of both stages.
    vs_inputs = validator.extract_shader_inputs(example_vs, vs_name)
    vs_outputs = validator.extract_shader_outputs(example_vs, vs_name)
    fs_inputs = validator.extract_shader_inputs(example_fs, fs_name)

    # Vertex stage against the vertex buffer layout.
    validator.validate_vertex_layout_match(vs_inputs, vertex_layout, vs_name)
    validator.validate_vertex_buffer_stride(vertex_layout, vertex_struct_size)
    validator.validate_location_continuity(vs_inputs, vs_name)

    # Per-stage SPIR-V rules, then the VS -> FS interface and naming.
    for stage_source, stage_name in ((example_vs, vs_name), (example_fs, fs_name)):
        validator.validate_spir_v_compatibility(stage_source, stage_name)
    validator.validate_interface_matching(vs_outputs, fs_inputs, vs_name, fs_name)
    validator.validate_attribute_name_conventions(vs_inputs, vs_name)

    return 0 if validator.print_report() else 1
if __name__ == '__main__':
    # Use main()'s 0/1 return value as the process exit status.
    raise SystemExit(main())