From 329938885797127c6f7ba661fe2537486c62570e Mon Sep 17 00:00:00 2001 From: johndoe6345789 Date: Mon, 5 Jan 2026 22:38:55 +0000 Subject: [PATCH] feat: Implement shader byte loading and caching in GuiRenderer and PipelineService --- src/services/impl/gui_renderer.cpp | 27 ++++++++- src/services/impl/gui_renderer.hpp | 3 + src/services/impl/pipeline_service.cpp | 80 +++++++++++++++----------- src/services/impl/pipeline_service.hpp | 4 +- tests/scripts/unit_cube_logic.lua | 4 +- tests/test_cube_script.cpp | 4 +- 6 files changed, 82 insertions(+), 40 deletions(-) diff --git a/src/services/impl/gui_renderer.cpp b/src/services/impl/gui_renderer.cpp index cdec16c..f0a2557 100644 --- a/src/services/impl/gui_renderer.cpp +++ b/src/services/impl/gui_renderer.cpp @@ -667,6 +667,29 @@ void GuiRenderer::GenerateGuiGeometry(const std::vector& commands, u } } +const std::vector& GuiRenderer::LoadShaderBytes(const std::filesystem::path& path, + VkShaderStageFlagBits stage) { + const std::string key = path.string(); + auto cached = shaderSpirvCache_.find(key); + if (cached != shaderSpirvCache_.end()) { + if (logger_) { + logger_->Trace("GuiRenderer", "LoadShaderBytes", + "cacheHit=true, path=" + key + + ", bytes=" + std::to_string(cached->second.size())); + } + return cached->second; + } + + std::vector shaderBytes = ReadShaderFile(path, stage, logger_.get()); + auto inserted = shaderSpirvCache_.emplace(key, std::move(shaderBytes)); + if (logger_) { + logger_->Trace("GuiRenderer", "LoadShaderBytes", + "cacheHit=false, path=" + key + + ", bytes=" + std::to_string(inserted.first->second.size())); + } + return inserted.first->second; +} + void GuiRenderer::CreatePipeline(VkRenderPass renderPass, VkExtent2D extent) { // Load shader modules const std::filesystem::path vertexShaderPath = @@ -681,8 +704,8 @@ void GuiRenderer::CreatePipeline(VkRenderPass renderPass, VkExtent2D extent) { ", fragmentShader=" + fragmentShaderPath.string()); } - auto vertShaderCode = ReadShaderFile(vertexShaderPath, VK_SHADER_STAGE_VERTEX_BIT, logger_.get()); - auto fragShaderCode = ReadShaderFile(fragmentShaderPath, VK_SHADER_STAGE_FRAGMENT_BIT, logger_.get()); + const auto& vertShaderCode = LoadShaderBytes(vertexShaderPath, VK_SHADER_STAGE_VERTEX_BIT); + const auto& fragShaderCode = LoadShaderBytes(fragmentShaderPath, VK_SHADER_STAGE_FRAGMENT_BIT); VkShaderModuleCreateInfo vertModuleInfo{}; vertModuleInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; diff --git a/src/services/impl/gui_renderer.hpp b/src/services/impl/gui_renderer.hpp index 192d982..07c6d58 100644 --- a/src/services/impl/gui_renderer.hpp +++ b/src/services/impl/gui_renderer.hpp @@ -57,6 +57,8 @@ private: void CleanupBuffers(); void UpdateFormat(VkFormat format); void GenerateGuiGeometry(const std::vector& commands, uint32_t width, uint32_t height); + const std::vector& LoadShaderBytes(const std::filesystem::path& path, + VkShaderStageFlagBits stage); VkDevice device_; VkPhysicalDevice physicalDevice_; @@ -83,6 +85,7 @@ private: uint32_t viewportWidth_ = 0; uint32_t viewportHeight_ = 0; std::unordered_map svgCache_; + std::unordered_map> shaderSpirvCache_; std::shared_ptr bufferService_; std::shared_ptr logger_; }; diff --git a/src/services/impl/pipeline_service.cpp b/src/services/impl/pipeline_service.cpp index 454feba..83e5c89 100644 --- a/src/services/impl/pipeline_service.cpp +++ b/src/services/impl/pipeline_service.cpp @@ -94,6 +94,7 @@ void PipelineService::RecreatePipelines(VkRenderPass renderPass, VkExtent2D exte void PipelineService::Cleanup() { logger_->Trace("PipelineService", "Cleanup"); CleanupPipelines(); + shaderSpirvCache_.clear(); auto device = deviceService_->GetDevice(); @@ -306,7 +307,7 @@ void PipelineService::CreatePipelinesInternal(VkRenderPass renderPass, VkExtent2 }; auto addStage = [&](VkShaderStageFlagBits stage, const std::string& path) { - auto shaderCode = ReadShaderFile(path, stage); + const auto& shaderCode = ReadShaderFile(path, stage); VkShaderModule shaderModule = CreateShaderModule(shaderCode); shaderModules.push_back(shaderModule); @@ -407,7 +408,7 @@ bool PipelineService::HasShaderSource(const std::string& path) const { return false; } -std::vector PipelineService::ReadShaderFile(const std::string& path, VkShaderStageFlagBits stage) { +const std::vector& PipelineService::ReadShaderFile(const std::string& path, VkShaderStageFlagBits stage) { logger_->Trace("PipelineService", "ReadShaderFile", "path=" + path + ", stage=" + std::to_string(static_cast(stage))); @@ -435,6 +436,16 @@ std::vector PipelineService::ReadShaderFile(const std::string& path, VkSha throw std::runtime_error("Path is not a regular file: " + shaderPath.string()); } + const std::string cacheKey = shaderPath.string() + "|" + + std::to_string(static_cast(stage)); + auto cached = shaderSpirvCache_.find(cacheKey); + if (cached != shaderSpirvCache_.end()) { + logger_->Trace("PipelineService", "ReadShaderFile", + "cacheHit=true, bytes=" + std::to_string(cached->second.size())); + return cached->second; + } + + std::vector buffer; if (IsSpirvPath(shaderPath)) { std::ifstream file(shaderPath, std::ios::ate | std::ios::binary); if (!file) { @@ -443,7 +454,7 @@ std::vector PipelineService::ReadShaderFile(const std::string& path, VkSha } size_t fileSize = static_cast(file.tellg()); - std::vector buffer(fileSize); + buffer.resize(fileSize); file.seekg(0); file.read(buffer.data(), static_cast(fileSize)); @@ -451,38 +462,41 @@ std::vector PipelineService::ReadShaderFile(const std::string& path, VkSha logger_->Debug("Read shader file: " + shaderPath.string() + " (" + std::to_string(fileSize) + " bytes)"); - return buffer; + } else { + std::ifstream sourceFile(shaderPath); + if (!sourceFile) { + throw std::runtime_error("Failed to open shader source: " + shaderPath.string()); + } + std::string source((std::istreambuf_iterator(sourceFile)), + std::istreambuf_iterator()); + sourceFile.close(); + + shaderc::Compiler compiler; + shaderc::CompileOptions options; + options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_2); + + shaderc_shader_kind kind = ShadercKindFromStage(stage); + auto result = compiler.CompileGlslToSpv(source, kind, shaderPath.string().c_str(), options); + if (result.GetCompilationStatus() != shaderc_compilation_status_success) { + std::string error = result.GetErrorMessage(); + logger_->Error("Shader compilation failed: " + shaderPath.string() + "\n" + error); + throw std::runtime_error("Shader compilation failed: " + shaderPath.string() + "\n" + error); + } + + std::vector spirv(result.cbegin(), result.cend()); + buffer.resize(spirv.size() * sizeof(uint32_t)); + if (!buffer.empty()) { + std::memcpy(buffer.data(), spirv.data(), buffer.size()); + } + + logger_->Debug("Compiled shader: " + shaderPath.string() + + " (" + std::to_string(buffer.size()) + " bytes)"); } - std::ifstream sourceFile(shaderPath); - if (!sourceFile) { - throw std::runtime_error("Failed to open shader source: " + shaderPath.string()); - } - std::string source((std::istreambuf_iterator(sourceFile)), - std::istreambuf_iterator()); - sourceFile.close(); - - shaderc::Compiler compiler; - shaderc::CompileOptions options; - options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_2); - - shaderc_shader_kind kind = ShadercKindFromStage(stage); - auto result = compiler.CompileGlslToSpv(source, kind, shaderPath.string().c_str(), options); - if (result.GetCompilationStatus() != shaderc_compilation_status_success) { - std::string error = result.GetErrorMessage(); - logger_->Error("Shader compilation failed: " + shaderPath.string() + "\n" + error); - throw std::runtime_error("Shader compilation failed: " + shaderPath.string() + "\n" + error); - } - - std::vector spirv(result.cbegin(), result.cend()); - std::vector buffer(spirv.size() * sizeof(uint32_t)); - if (!buffer.empty()) { - std::memcpy(buffer.data(), spirv.data(), buffer.size()); - } - - logger_->Debug("Compiled shader: " + shaderPath.string() + - " (" + std::to_string(buffer.size()) + " bytes)"); - return buffer; + auto inserted = shaderSpirvCache_.emplace(cacheKey, std::move(buffer)); + logger_->Trace("PipelineService", "ReadShaderFile", + "cacheHit=false, bytes=" + std::to_string(inserted.first->second.size())); + return inserted.first->second; } } // namespace sdl3cpp::services::impl diff --git a/src/services/impl/pipeline_service.hpp b/src/services/impl/pipeline_service.hpp index ac478b3..6f297e9 100644 --- a/src/services/impl/pipeline_service.hpp +++ b/src/services/impl/pipeline_service.hpp @@ -52,11 +52,13 @@ private: // Helper methods VkShaderModule CreateShaderModule(const std::vector& code); - std::vector ReadShaderFile(const std::string& path, VkShaderStageFlagBits stage); + const std::vector& ReadShaderFile(const std::string& path, VkShaderStageFlagBits stage); bool HasShaderSource(const std::string& path) const; void CreatePipelineLayout(); void CreatePipelinesInternal(VkRenderPass renderPass, VkExtent2D extent); void CleanupPipelines(); + + std::unordered_map> shaderSpirvCache_; }; } // namespace sdl3cpp::services::impl diff --git a/tests/scripts/unit_cube_logic.lua b/tests/scripts/unit_cube_logic.lua index 17a304c..eb4cdfb 100644 --- a/tests/scripts/unit_cube_logic.lua +++ b/tests/scripts/unit_cube_logic.lua @@ -27,8 +27,8 @@ end function get_shader_paths() return { test = { - vertex = "shaders/test.vert.spv", - fragment = "shaders/test.frag.spv", + vertex = "shaders/test.vert", + fragment = "shaders/test.frag", }, } end diff --git a/tests/test_cube_script.cpp b/tests/test_cube_script.cpp index 8a3ae0a..5a35c3f 100644 --- a/tests/test_cube_script.cpp +++ b/tests/test_cube_script.cpp @@ -97,8 +97,8 @@ int main() { auto testEntry = shaderMap.find("test"); Assert(testEntry != shaderMap.end(), "shader map missing test entry", failures); if (testEntry != shaderMap.end()) { - Assert(testEntry->second.vertex == "shaders/test.vert.spv", "vertex shader path", failures); - Assert(testEntry->second.fragment == "shaders/test.frag.spv", "fragment shader path", failures); + Assert(testEntry->second.vertex == "shaders/test.vert", "vertex shader path", failures); + Assert(testEntry->second.fragment == "shaders/test.frag", "fragment shader path", failures); } } catch (const std::exception& ex) { std::cerr << "exception during tests: " << ex.what() << '\n';