#include "impeller/renderer/backend/vulkan/compute_pass_vk.h"

#include <algorithm>

#include "impeller/renderer/backend/vulkan/command_buffer_vk.h"
#include "impeller/renderer/backend/vulkan/compute_pipeline_vk.h"
#include "impeller/renderer/backend/vulkan/context_vk.h"
#include "impeller/renderer/backend/vulkan/device_buffer_vk.h"
#include "impeller/renderer/backend/vulkan/formats_vk.h"
#include "impeller/renderer/backend/vulkan/sampler_vk.h"
#include "impeller/renderer/backend/vulkan/texture_vk.h"
#include "vulkan/vulkan_structs.hpp"

namespace impeller {
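// ComputePassVK records compute work into a Vulkan command buffer. Resource
// bindings are staged into member workspace arrays and flushed as descriptor
// writes when Compute() is called, so each dispatch consumes a single
// descriptor set allocated from the command buffer.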
ComputePassVK::ComputePassVK(std::shared_ptr<const Context> context,
                             std::shared_ptr<CommandBufferVK> command_buffer)
    : ComputePass(std::move(context)),
      command_buffer_(std::move(command_buffer)) {
  max_wg_size_ = ContextVK::Cast(*context_)
                     .GetPhysicalDevice()
                     .getProperties()
                     .limits.maxComputeWorkGroupSize;
  is_valid_ = true;
}
ComputePassVK::~ComputePassVK() = default;

bool ComputePassVK::IsValid() const {
  return is_valid_;
}
void ComputePassVK::OnSetLabel(const std::string& label) {
  if (label.empty()) {
    return;
  }
  label_ = label;
}
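// Command labels are debug-only: SetCommandLabel() opens a debug group on the
// command buffer and Compute() closes it again once the dispatch has been
// recorded, scoping the label to a single dispatch.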
void ComputePassVK::SetCommandLabel(std::string_view label) {
#ifdef IMPELLER_DEBUG
  command_buffer_->PushDebugGroup(label);
  has_label_ = true;
#endif  // IMPELLER_DEBUG
}
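// Binding a pipeline records the bind immediately and allocates a fresh
// descriptor set for the upcoming dispatch. If allocation fails, the pass is
// left invalid and Compute() will refuse to dispatch.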
void ComputePassVK::SetPipeline(
    const std::shared_ptr<Pipeline<ComputePipelineDescriptor>>& pipeline) {
  const auto& pipeline_vk = ComputePipelineVK::Cast(*pipeline);
  const vk::CommandBuffer& command_buffer_vk =
      command_buffer_->GetCommandBuffer();
  command_buffer_vk.bindPipeline(vk::PipelineBindPoint::eCompute,
                                 pipeline_vk.GetPipeline());
  pipeline_layout_ = pipeline_vk.GetPipelineLayout();

  auto descriptor_result = command_buffer_->AllocateDescriptorSets(
      pipeline_vk.GetDescriptorSetLayout(), pipeline_vk.GetPipelineKey(),
      ContextVK::Cast(*context_));
  if (!descriptor_result.ok()) {
    return;
  }
  descriptor_set_ = descriptor_result.value();
  pipeline_valid_ = true;
}
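// The grid is dispatched as a (width, height, 1) group count. Dimensions that
// exceed the device's reported maxComputeWorkGroupSize are repeatedly halved
// until they fit, so compute shaders are presumably expected to loop over or
// bounds-check the full grid rather than rely on one invocation per cell.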
fml::Status ComputePassVK::Compute(const ISize& grid_size) {
  if (grid_size.IsEmpty() || !pipeline_valid_) {
    bound_image_offset_ = 0u;
    bound_buffer_offset_ = 0u;
    descriptor_write_offset_ = 0u;
    has_label_ = false;
    pipeline_valid_ = false;
    return fml::Status(fml::StatusCode::kCancelled,
                       "Invalid pipeline or empty grid.");
  }

  const ContextVK& context_vk = ContextVK::Cast(*context_);
  for (auto i = 0u; i < descriptor_write_offset_; i++) {
    write_workspace_[i].dstSet = descriptor_set_;
  }

  context_vk.GetDevice().updateDescriptorSets(descriptor_write_offset_,
                                              write_workspace_.data(), 0u, {});
  const vk::CommandBuffer& command_buffer_vk =
      command_buffer_->GetCommandBuffer();

  command_buffer_vk.bindDescriptorSets(
      vk::PipelineBindPoint::eCompute,  // bind point
      pipeline_layout_,                 // layout
      0,                                // firstSet
      1,                                // setCount
      &descriptor_set_,                 // sets
      0,                                // offset count
      nullptr                           // offsets
  );

  int64_t width = grid_size.width;
  int64_t height = grid_size.height;

  // Special case for linear processing.
  if (height == 1) {
    command_buffer_vk.dispatch(width, 1, 1);
  } else {
    while (width > max_wg_size_[0]) {
      width = std::max(static_cast<int64_t>(1), width / 2);
    }
    while (height > max_wg_size_[1]) {
      height = std::max(static_cast<int64_t>(1), height / 2);
    }
    command_buffer_vk.dispatch(width, height, 1);
  }

#ifdef IMPELLER_DEBUG
  if (has_label_) {
    command_buffer_->PopDebugGroup();
  }
  has_label_ = false;
#endif  // IMPELLER_DEBUG

  bound_image_offset_ = 0u;
  bound_buffer_offset_ = 0u;
  descriptor_write_offset_ = 0u;
  has_label_ = false;
  pipeline_valid_ = false;

  return fml::Status();
}
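// The BindResource overloads never touch the Vulkan API directly. Each call
// validates and tracks the resource, then appends a vk::WriteDescriptorSet
// (plus its image/buffer info) to the member workspaces; Compute() later
// patches in the destination set and issues one updateDescriptorSets call.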
bool ComputePassVK::BindResource(ShaderStage stage,
                                 DescriptorType type,
                                 const ShaderUniformSlot& slot,
                                 const ShaderMetadata* metadata,
                                 BufferView view) {
  return BindResource(slot.binding, type, view);
}
bool ComputePassVK::BindResource(ShaderStage stage,
                                 DescriptorType type,
                                 const SampledImageSlot& slot,
                                 const ShaderMetadata* metadata,
                                 std::shared_ptr<const Texture> texture,
                                 raw_ptr<const Sampler> sampler) {
  if (bound_image_offset_ >= kMaxBindings) {
    return false;
  }
  if (!texture->IsValid() || !sampler) {
    return false;
  }
  const TextureVK& texture_vk = TextureVK::Cast(*texture);
  const SamplerVK& sampler_vk = SamplerVK::Cast(*sampler);

  if (!command_buffer_->Track(texture)) {
    return false;
  }

  vk::DescriptorImageInfo image_info;
  image_info.imageLayout = vk::ImageLayout::eShaderReadOnlyOptimal;
  image_info.sampler = sampler_vk.GetSampler();
  image_info.imageView = texture_vk.GetImageView();
  image_workspace_[bound_image_offset_++] = image_info;

  vk::WriteDescriptorSet write_set;
  write_set.dstBinding = slot.binding;
  write_set.descriptorCount = 1u;
  write_set.descriptorType = ToVKDescriptorType(type);
  write_set.pImageInfo = &image_workspace_[bound_image_offset_ - 1];

  write_workspace_[descriptor_write_offset_++] = write_set;
  return true;
}
bool ComputePassVK::BindResource(size_t binding,
                                 DescriptorType type,
                                 BufferView view) {
  if (bound_buffer_offset_ >= kMaxBindings) {
    return false;
  }

  auto buffer = DeviceBufferVK::Cast(*view.GetBuffer()).GetBuffer();
  if (!buffer) {
    return false;
  }

  std::shared_ptr<const DeviceBuffer> device_buffer = view.TakeBuffer();
  if (device_buffer && !command_buffer_->Track(device_buffer)) {
    return false;
  }

  uint32_t offset = view.GetRange().offset;

  vk::DescriptorBufferInfo buffer_info;
  buffer_info.buffer = buffer;
  buffer_info.offset = offset;
  buffer_info.range = view.GetRange().length;
  buffer_workspace_[bound_buffer_offset_++] = buffer_info;

  vk::WriteDescriptorSet write_set;
  write_set.dstBinding = binding;
  write_set.descriptorCount = 1u;
  write_set.descriptorType = ToVKDescriptorType(type);
  write_set.pBufferInfo = &buffer_workspace_[bound_buffer_offset_ - 1];

  write_workspace_[descriptor_write_offset_++] = write_set;
  return true;
}
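// Both barrier helpers below issue a global vk::MemoryBarrier rather than
// per-resource buffer/image barriers. The working assumption is that a global
// shader-write -> shader-read barrier between compute stages is sufficient
// here and unlikely to cost more than finer-grained alternatives.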
void ComputePassVK::AddBufferMemoryBarrier() {
  vk::MemoryBarrier barrier;
  barrier.srcAccessMask = vk::AccessFlagBits::eShaderWrite;
  barrier.dstAccessMask = vk::AccessFlagBits::eShaderRead;

  command_buffer_->GetCommandBuffer().pipelineBarrier(
      vk::PipelineStageFlagBits::eComputeShader,
      vk::PipelineStageFlagBits::eComputeShader, {}, 1, &barrier, 0, {}, 0, {});
}
void ComputePassVK::AddTextureMemoryBarrier() {
  vk::MemoryBarrier barrier;
  barrier.srcAccessMask = vk::AccessFlagBits::eShaderWrite;
  barrier.dstAccessMask = vk::AccessFlagBits::eShaderRead;

  command_buffer_->GetCommandBuffer().pipelineBarrier(
      vk::PipelineStageFlagBits::eComputeShader,
      vk::PipelineStageFlagBits::eComputeShader, {}, 1, &barrier, 0, {}, 0, {});
}
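// Ending the pass emits a final barrier from the compute stage to vertex
// input, making shader writes visible to later index and vertex attribute
// reads; this pessimistically assumes subsequent render work may consume the
// compute output as geometry.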
bool ComputePassVK::EncodeCommands() const {
  vk::MemoryBarrier barrier;
  barrier.srcAccessMask = vk::AccessFlagBits::eShaderWrite;
  barrier.dstAccessMask =
      vk::AccessFlagBits::eIndexRead | vk::AccessFlagBits::eVertexAttributeRead;

  command_buffer_->GetCommandBuffer().pipelineBarrier(
      vk::PipelineStageFlagBits::eComputeShader,
      vk::PipelineStageFlagBits::eVertexInput, {}, 1, &barrier, 0, {}, 0, {});

  return true;
}
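// Rough usage sketch (not part of this file; assumes the surrounding Impeller
// API such as CommandBuffer::CreateComputePass() and an already-created
// compute pipeline and buffer view):
//
//   auto cmd_buffer = context->CreateCommandBuffer();
//   auto pass = cmd_buffer->CreateComputePass();
//   pass->SetPipeline(compute_pipeline);
//   pass->BindResource(/*binding=*/0u, DescriptorType::kStorageBuffer, view);
//   pass->Compute(ISize(1024, 1));  // Dispatch over a 1024x1 grid.
//   pass->EncodeCommands();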
}  // namespace impeller