#include <Metal/Metal.h>

#include <algorithm>
#include <cmath>
#include <cstdint>

#include "flutter/fml/backtrace.h"
#include "flutter/fml/closure.h"
#include "flutter/fml/logging.h"
#include "flutter/fml/trace_event.h"
#include "fml/status.h"
// Creates a compute pass that encodes into a concurrent Metal compute
// command encoder obtained from |buffer|.
//
// The pass is only usable (IsValid() == true) when both the command buffer
// and the derived encoder are non-nil. Concurrent dispatch is requested so
// independent dispatches may overlap; explicit memory barriers
// (AddBufferMemoryBarrier / AddTextureMemoryBarrier) order dependent work.
ComputePassMTL::ComputePassMTL(std::shared_ptr<const Context> context,
                               id<MTLCommandBuffer> buffer)
    : ComputePass(std::move(context)), buffer_(buffer) {
  if (!buffer_) {
    return;
  }
  encoder_ = [buffer_ computeCommandEncoderWithDispatchType:
                          MTLDispatchType::MTLDispatchTypeConcurrent];
  if (!encoder_) {
    return;
  }
  // NOTE(review): reconstructed from a truncated extract — assumes the class
  // declares `is_valid_` and that the bindings cache takes the encoder here;
  // verify against compute_pass_mtl.h.
  is_valid_ = true;
  pass_bindings_cache_.SetEncoder(encoder_);
}
// Whether construction succeeded (command buffer and encoder are live).
// Callers must check this before encoding any compute work.
bool ComputePassMTL::IsValid() const {
  return is_valid_;
}
// Applies a debug label to the underlying compute encoder so the pass is
// identifiable in Xcode GPU frame captures. Compiled out of release builds.
void ComputePassMTL::OnSetLabel(const std::string& label) {
#ifdef IMPELLER_DEBUG
  if (label.empty()) {
    return;
  }
  // @() boxing is safe here: std::string::c_str() is NUL-terminated.
  [encoder_ setLabel:@(label.c_str())];
#endif  // IMPELLER_DEBUG
}
// Opens a debug group named after the current command for GPU frame captures.
// The matching popDebugGroup happens at the end of Compute().
//
// BUG FIX: the previous code passed `label.data()` to @()-boxing. A
// std::string_view is NOT guaranteed to be NUL-terminated, so that could read
// past the end of the viewed buffer. Materialize a std::string first —
// c_str() guarantees termination.
void ComputePassMTL::SetCommandLabel(std::string_view label) {
#ifdef IMPELLER_DEBUG
  // NOTE(review): assumes the class tracks an open debug group via
  // `has_label_` (consumed in Compute()) — confirm against the header.
  has_label_ = true;
  [encoder_ pushDebugGroup:@(std::string(label).c_str())];
#endif  // IMPELLER_DEBUG
}
// Binds the compute pipeline state for subsequent dispatches. Goes through
// the bindings cache so redundant setComputePipelineState: calls on the
// encoder are elided.
void ComputePassMTL::SetPipeline(
    const std::shared_ptr<Pipeline<ComputePipelineDescriptor>>& pipeline) {
  // NOTE(review): reconstructed from a truncated extract — assumes the Metal
  // backend downcast helper; verify against compute_pipeline_mtl.h.
  pass_bindings_cache_.SetComputePipelineState(
      ComputePipelineMTL::Cast(*pipeline).GetMTLComputePipelineState());
}
// Inserts a buffer-scope memory barrier: all buffer writes by dispatches
// encoded before this call are visible to dispatches encoded after it.
// Required for correctness because the encoder uses concurrent dispatch.
void ComputePassMTL::AddBufferMemoryBarrier() {
  [encoder_ memoryBarrierWithScope:MTLBarrierScopeBuffers];
}
// Inserts a texture-scope memory barrier: all texture writes by dispatches
// encoded before this call are visible to dispatches encoded after it.
void ComputePassMTL::AddTextureMemoryBarrier() {
  [encoder_ memoryBarrierWithScope:MTLBarrierScopeTextures];
}
// Binds a device buffer to a compute shader uniform slot.
//
// |stage|/|type|/|metadata| are unused by the Metal backend — the slot's
// ext_res_0 index fully determines the encoder binding point. The offset
// within the Metal buffer comes from the view's range. Returns false when
// the view has no backing buffer or the MTLBuffer cannot be obtained.
bool ComputePassMTL::BindResource(ShaderStage stage,
                                  DescriptorType type,  // NOTE(review): param list reconstructed — verify against header.
                                  const ShaderUniformSlot& slot,
                                  const ShaderMetadata& metadata,
                                  BufferView view) {
  if (!view.buffer) {
    return false;
  }

  const std::shared_ptr<const DeviceBuffer>& device_buffer = view.buffer;
  id<MTLBuffer> buffer = DeviceBufferMTL::Cast(*device_buffer).GetMTLBuffer();
  // The Metal call is not allowed to be nil.
  if (!buffer) {
    return false;
  }

  // Cache elides redundant setBuffer:/setBufferOffset: calls on the encoder.
  pass_bindings_cache_.SetBuffer(slot.ext_res_0, view.range.offset, buffer);
  return true;
}
// Binds a sampled texture plus its sampler to a compute shader slot.
//
// Returns false if the sampler is missing or the texture is invalid;
// otherwise installs both through the bindings cache (which skips redundant
// encoder calls) at the slot's texture and sampler indices.
bool ComputePassMTL::BindResource(
    ShaderStage stage,
    DescriptorType type,  // NOTE(review): leading params reconstructed from a truncated extract — verify against header.
    const SampledImageSlot& slot,
    const ShaderMetadata& metadata,
    std::shared_ptr<const Texture> texture,
    const std::unique_ptr<const Sampler>& sampler) {
  if (!sampler || !texture->IsValid()) {
    return false;
  }

  pass_bindings_cache_.SetTexture(slot.texture_index,
                                  TextureMTL::Cast(*texture).GetMTLTexture());
  // NOTE(review): sampler binding reconstructed — assumes a SamplerMTL
  // downcast mirroring the texture path; confirm against sampler_mtl.h.
  pass_bindings_cache_.SetSampler(
      slot.sampler_index, SamplerMTL::Cast(*sampler).GetMTLSamplerState());
  return true;
}
// Encodes a dispatch covering |grid_size| threads.
//
// Strategy (reconstructed around the surviving statements — verify against
// upstream compute_pass_mtl.mm):
//   * height == 1 (linear grid): dispatch ceil(width / maxThreads)
//     threadgroups of maxThreads x 1 x 1. The kernel is expected to bounds-
//     check, since the last group may overhang the grid.
//   * 2-D grid: halve width/height until the threadgroup fits within the
//     pipeline's maxTotalThreadsPerThreadgroup, then dispatch a square-ish
//     size x size arrangement.
// Returns an error status for an empty grid; OK otherwise.
fml::Status ComputePassMTL::Compute(const ISize& grid_size) {
  if (grid_size.IsEmpty()) {
    return fml::Status(fml::StatusCode::kUnknown,
                       "Invalid grid size for compute command.");
  }

  auto width = grid_size.width;
  auto height = grid_size.height;

  // Upper bound on threads per threadgroup for the currently bound pipeline.
  auto max_total_threads_per_threadgroup = static_cast<int64_t>(
      pass_bindings_cache_.GetPipeline().maxTotalThreadsPerThreadgroup);

  if (height == 1) {
    // Linear case: at least one threadgroup, enough to cover all of width.
    int64_t thread_groups = std::max(
        static_cast<int64_t>(
            std::ceil(width * 1.0 / max_total_threads_per_threadgroup * 1.0)),
        1LL);
    [encoder_
        dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
       threadsPerThreadgroup:MTLSizeMake(max_total_threads_per_threadgroup, 1,
                                         1)];
  } else {
    // 2-D case: shrink the threadgroup until it fits the pipeline limit.
    while (width * height > max_total_threads_per_threadgroup) {
      width = std::max(1LL, width / 2);
      height = std::max(1LL, height / 2);
    }
    auto size = MTLSizeMake(width, height, 1);
    [encoder_ dispatchThreadgroups:size threadsPerThreadgroup:size];
  }

#ifdef IMPELLER_DEBUG
  // Close the debug group opened by SetCommandLabel, if any.
  if (has_label_) {
    [encoder_ popDebugGroup];
  }
  has_label_ = false;
#endif  // IMPELLER_DEBUG
  return fml::Status();
}
// Finalizes the pass by ending encoding on the compute encoder. No further
// bindings or dispatches may be encoded after this call.
bool ComputePassMTL::EncodeCommands() const {
  [encoder_ endEncoding];
  return true;
}