7 #include <Metal/Metal.h>
10 #include "flutter/fml/backtrace.h"
11 #include "flutter/fml/closure.h"
12 #include "flutter/fml/logging.h"
13 #include "flutter/fml/trace_event.h"
14 #include "fml/status.h"
// Constructs a compute pass that encodes into the given Metal command buffer.
// The context is forwarded to the ComputePass base; the command buffer is
// retained in buffer_ for the lifetime of the pass.
// NOTE(review): this chunk elides source lines 32-34 and 37-39 — likely an
// encoder nil-check / is_valid_ assignment; confirm against the full file.
29 ComputePassMTL::ComputePassMTL(std::shared_ptr<const Context> context,
30 id<MTLCommandBuffer> buffer)
31 : ComputePass(
std::move(context)), buffer_(buffer) {
// Concurrent dispatch lets the driver overlap dispatches within this pass;
// ordering between dependent dispatches must then be established explicitly
// via the memoryBarrierWithScope: calls exposed below.
35 encoder_ = [buffer_ computeCommandEncoderWithDispatchType:
36 MTLDispatchType::MTLDispatchTypeConcurrent];
// Hand the encoder to the bindings cache so it can forward (and de-dup)
// pipeline/buffer/texture state-setting calls.
40 pass_bindings_cache_.SetEncoder(encoder_);
// Out-of-line defaulted destructor; ARC releases buffer_/encoder_ here.
ComputePassMTL::~ComputePassMTL() = default;
// Reports whether this pass is usable for encoding.
// NOTE(review): the method body is elided from this chunk (source lines
// 47-48) — presumably returns an is_valid_ flag; confirm in the full file.
46 bool ComputePassMTL::IsValid()
const {
// Applies a debug label to the underlying compute command encoder.
// NOTE(review): source lines 51-54 are elided — likely a debug-only guard
// and/or an empty-label early return; confirm against the full file.
50 void ComputePassMTL::OnSetLabel(
const std::string& label) {
// @() boxes the NUL-terminated C string into an NSString for Metal.
55 [encoder_ setLabel:@(label.c_str())];
// Opens a debug group on the encoder named after the current command; the
// matching popDebugGroup is issued at the end of Compute().
59 void ComputePassMTL::SetCommandLabel(std::string_view label) {
// NOTE(review): @(label.data()) treats the view as a NUL-terminated C
// string — std::string_view does not guarantee termination, so a
// non-terminated view would over-read. Confirm all callers pass
// NUL-terminated labels (source line 60 and the closing of this method are
// elided from this chunk, possibly including an IMPELLER_DEBUG guard).
61 [encoder_ pushDebugGroup:@(label.data())];
// Binds the compute pipeline state for subsequent dispatches. Routed through
// the bindings cache so a redundant re-bind of the same MTLComputePipelineState
// can be skipped (closing brace elided from this chunk).
65 void ComputePassMTL::SetPipeline(
66 const std::shared_ptr<Pipeline<ComputePipelineDescriptor>>& pipeline) {
67 pass_bindings_cache_.SetComputePipelineState(
68 ComputePipelineMTL::Cast(*pipeline).GetMTLComputePipelineState());
// Inserts a buffer-scope memory barrier so later dispatches in this
// concurrently-dispatched pass observe buffer writes made by earlier ones
// (closing brace elided from this chunk).
72 void ComputePassMTL::AddBufferMemoryBarrier() {
73 [encoder_ memoryBarrierWithScope:MTLBarrierScopeBuffers];
// Inserts a texture-scope memory barrier so later dispatches in this
// concurrently-dispatched pass observe texture writes made by earlier ones
// (closing brace elided from this chunk).
77 void ComputePassMTL::AddTextureMemoryBarrier() {
78 [encoder_ memoryBarrierWithScope:MTLBarrierScopeTextures];
// Binds a uniform/storage buffer view to the compute stage slot.
// NOTE(review): this chunk elides the BufferView parameter declaration
// (source line 86), the early-return body and closing of the nil check
// (88-90), intermediate statements (92-95, 97-101), the trailing arguments
// of SetBuffer, and the return — read the full file before editing.
82 bool ComputePassMTL::BindResource(
ShaderStage stage,
84 const ShaderUniformSlot& slot,
85 const ShaderMetadata* metadata,
// Reject views that carry no backing device buffer.
87 if (!view.GetBuffer()) {
91 const DeviceBuffer* device_buffer = view.GetBuffer();
// Unwrap the backend-specific buffer to get the raw MTLBuffer handle.
96 id<MTLBuffer> buffer = DeviceBufferMTL::Cast(*device_buffer).GetMTLBuffer();
// Bind via the cache at the slot's buffer index, honoring the view's offset
// into the underlying allocation.
102 pass_bindings_cache_.SetBuffer(slot.ext_res_0, view.GetRange().offset,
// Binds a sampled texture plus its sampler to the compute stage slot.
// Both are bound at slot.texture_index (Metal keeps separate texture and
// sampler index spaces, so reusing the index is valid).
// NOTE(review): the early-return body (source lines 115-117) and the final
// `return true;`/closing brace are elided from this chunk.
108 bool ComputePassMTL::BindResource(
ShaderStage stage,
110 const SampledImageSlot& slot,
111 const ShaderMetadata* metadata,
112 std::shared_ptr<const Texture> texture,
113 raw_ptr<const Sampler> sampler) {
// Reject missing samplers and invalid textures up front.
114 if (!sampler || !texture->IsValid()) {
118 pass_bindings_cache_.SetTexture(slot.texture_index,
119 TextureMTL::Cast(*texture).GetMTLTexture());
120 pass_bindings_cache_.SetSampler(
121 slot.texture_index, SamplerMTL::Cast(*sampler).GetMTLSamplerState());
// Dispatches the currently-bound compute pipeline over `grid_size` threads.
// Returns an error status for an empty grid, OK otherwise.
// NOTE(review): many source lines are elided from this chunk (129-131, 134,
// 137-139, 143-144, 147-148, 152-153, 156-157, 159, 161-163), including what
// appears to be the branch split between a 1-D dispatch path (lines 140-146)
// and a 2-D halving path (149-155) — confirm the control flow in the full
// file before modifying.
125 fml::Status ComputePassMTL::Compute(
const ISize& grid_size) {
126 if (grid_size.IsEmpty()) {
127 return fml::Status(fml::StatusCode::kUnknown,
128 "Invalid grid size for compute command.");
132 auto width = grid_size.width;
133 auto height = grid_size.height;
// Device/pipeline limit on threads per threadgroup, from the cached
// MTLComputePipelineState.
135 auto max_total_threads_per_threadgroup =
static_cast<int64_t
>(
136 pass_bindings_cache_.GetPipeline().maxTotalThreadsPerThreadgroup);
// 1-D path: enough full-width threadgroups to cover `width` threads
// (ceil division), clamped to at least one group (clamp arg elided here).
140 int64_t thread_groups = std::max(
141 static_cast<int64_t
>(
142 std::ceil(width * 1.0 / max_total_threads_per_threadgroup * 1.0)),
145 dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
146 threadsPerThreadgroup:MTLSizeMake(max_total_threads_per_threadgroup, 1,
// 2-D path: halve both dimensions (never below 1) until a single
// width x height threadgroup fits under the device limit, then dispatch a
// width x height grid of width x height groups.
149 while (width * height > max_total_threads_per_threadgroup) {
150 width = std::max(1LL, width / 2);
151 height = std::max(1LL, height / 2);
154 auto size = MTLSizeMake(width, height, 1);
155 [encoder_ dispatchThreadgroups:size threadsPerThreadgroup:size];
// Close the debug group opened by SetCommandLabel (debug builds only;
// the surrounding guard/flag lines are elided from this chunk).
158 #ifdef IMPELLER_DEBUG
160 [encoder_ popDebugGroup];
164 return fml::Status();
// Finalizes the pass: ends encoding on the compute command encoder. No
// further state-setting or dispatch calls are legal on encoder_ after this.
// NOTE(review): the return statement and closing brace are elided from this
// chunk.
167 bool ComputePassMTL::EncodeCommands()
const {
168 [encoder_ endEncoding];