Flutter Impeller
compute_pass_mtl.mm
// Copyright 2013 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "impeller/renderer/backend/metal/compute_pass_mtl.h"

#include <Metal/Metal.h>
#include <memory>

#include "flutter/fml/backtrace.h"
#include "flutter/fml/closure.h"
#include "flutter/fml/logging.h"
#include "flutter/fml/trace_event.h"
#include "fml/status.h"
#include "impeller/core/device_buffer.h"
#include "impeller/core/formats.h"
#include "impeller/renderer/backend/metal/compute_pipeline_mtl.h"
#include "impeller/renderer/backend/metal/device_buffer_mtl.h"
#include "impeller/renderer/backend/metal/sampler_mtl.h"
#include "impeller/renderer/backend/metal/texture_mtl.h"

namespace impeller {

ComputePassMTL::ComputePassMTL(std::shared_ptr<const Context> context,
                               id<MTLCommandBuffer> buffer)
    : ComputePass(std::move(context)), buffer_(buffer) {
  if (!buffer_) {
    return;
  }
  encoder_ = [buffer_ computeCommandEncoderWithDispatchType:
                          MTLDispatchType::MTLDispatchTypeConcurrent];
  if (!encoder_) {
    return;
  }
  pass_bindings_cache_.SetEncoder(encoder_);
  is_valid_ = true;
}

ComputePassMTL::~ComputePassMTL() = default;

bool ComputePassMTL::IsValid() const {
  return is_valid_;
}

void ComputePassMTL::OnSetLabel(const std::string& label) {
#ifdef IMPELLER_DEBUG
  if (label.empty()) {
    return;
  }
  [encoder_ setLabel:@(label.c_str())];
#endif  // IMPELLER_DEBUG
}

void ComputePassMTL::SetCommandLabel(std::string_view label) {
  has_label_ = true;
  // |label| is a std::string_view, which is not guaranteed to be
  // NUL-terminated; copy it into a std::string before boxing it into an
  // NSString.
  [encoder_ pushDebugGroup:@(std::string(label).c_str())];
}

// |ComputePass|
void ComputePassMTL::SetPipeline(
    const std::shared_ptr<Pipeline<ComputePipelineDescriptor>>& pipeline) {
  pass_bindings_cache_.SetComputePipelineState(
      ComputePipelineMTL::Cast(*pipeline).GetMTLComputePipelineState());
}

// |ComputePass|
void ComputePassMTL::AddBufferMemoryBarrier() {
  [encoder_ memoryBarrierWithScope:MTLBarrierScopeBuffers];
}

// |ComputePass|
void ComputePassMTL::AddTextureMemoryBarrier() {
  [encoder_ memoryBarrierWithScope:MTLBarrierScopeTextures];
}
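
// Note on the two barrier methods above: the encoder is created with
// MTLDispatchTypeConcurrent in the constructor, so Metal is free to overlap
// dispatches recorded into this pass. Callers insert these barriers between
// dispatches that have a producer/consumer relationship. A hypothetical
// caller sequence (names are placeholders, not part of this file):
//
//   pass->SetPipeline(producer_pipeline);
//   pass->Compute(ISize(1024, 1));    // writes a storage buffer
//   pass->AddBufferMemoryBarrier();   // make the write visible to what follows
//   pass->SetPipeline(consumer_pipeline);
//   pass->Compute(ISize(1024, 1));    // reads the same storage buffer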

// |ComputePass|
bool ComputePassMTL::BindResource(ShaderStage stage,
                                  DescriptorType type,
                                  const ShaderUniformSlot& slot,
                                  const ShaderMetadata* metadata,
                                  BufferView view) {
  if (!view.GetBuffer()) {
    return false;
  }

  const DeviceBuffer* device_buffer = view.GetBuffer();
  if (!device_buffer) {
    return false;
  }

  id<MTLBuffer> buffer = DeviceBufferMTL::Cast(*device_buffer).GetMTLBuffer();
  // The Metal call is a void return and we don't want to make it on nil.
  if (!buffer) {
    return false;
  }

  pass_bindings_cache_.SetBuffer(slot.ext_res_0, view.GetRange().offset,
                                 buffer);
  return true;
}

// |ComputePass|
bool ComputePassMTL::BindResource(ShaderStage stage,
                                  DescriptorType type,
                                  const SampledImageSlot& slot,
                                  const ShaderMetadata* metadata,
                                  std::shared_ptr<const Texture> texture,
                                  raw_ptr<const Sampler> sampler) {
  if (!sampler || !texture->IsValid()) {
    return false;
  }

  pass_bindings_cache_.SetTexture(slot.texture_index,
                                  TextureMTL::Cast(*texture).GetMTLTexture());
  pass_bindings_cache_.SetSampler(
      slot.texture_index, SamplerMTL::Cast(*sampler).GetMTLSamplerState());
  return true;
}

fml::Status ComputePassMTL::Compute(const ISize& grid_size) {
  if (grid_size.IsEmpty()) {
    return fml::Status(fml::StatusCode::kUnknown,
                       "Invalid grid size for compute command.");
  }

  // Threadgroup sizes must be uniform.
  auto width = grid_size.width;
  auto height = grid_size.height;

  auto max_total_threads_per_threadgroup = static_cast<int64_t>(
      pass_bindings_cache_.GetPipeline().maxTotalThreadsPerThreadgroup);

  // Special case for linear processing.
  if (height == 1) {
    int64_t thread_groups = std::max(
        static_cast<int64_t>(
            std::ceil(width * 1.0 / max_total_threads_per_threadgroup)),
        1LL);
    [encoder_
         dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
        threadsPerThreadgroup:MTLSizeMake(max_total_threads_per_threadgroup, 1,
                                          1)];
  } else {
    while (width * height > max_total_threads_per_threadgroup) {
      width = std::max(1LL, width / 2);
      height = std::max(1LL, height / 2);
    }

    auto size = MTLSizeMake(width, height, 1);
    [encoder_ dispatchThreadgroups:size threadsPerThreadgroup:size];
  }

#ifdef IMPELLER_DEBUG
  if (has_label_) {
    [encoder_ popDebugGroup];
  }
  has_label_ = false;
#endif  // IMPELLER_DEBUG
  return fml::Status();
}
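
// Worked example of the dispatch arithmetic above (values assumed for
// illustration; the real limit is queried from the pipeline state at
// runtime). Assuming maxTotalThreadsPerThreadgroup == 1024:
//
//   Linear grid {4096, 1}:
//     thread_groups = max(ceil(4096 / 1024.0), 1) = 4
//     -> dispatchThreadgroups {4, 1, 1}, threadsPerThreadgroup {1024, 1, 1}
//
//   2D grid {64, 64}:
//     64 * 64 = 4096 > 1024, so both dimensions are halved once: 32 * 32 = 1024
//     -> dispatchThreadgroups {32, 32, 1}, threadsPerThreadgroup {32, 32, 1}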

bool ComputePassMTL::EncodeCommands() const {
  [encoder_ endEncoding];
  return true;
}

}  // namespace impeller
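
For context, the sketch below shows roughly how this pass is driven through the public ComputePass interface implemented here. It is a minimal sketch, assuming the Context::CreateCommandBuffer and CommandBuffer::CreateComputePass factory methods and an already-built compute pipeline; variable names such as context and compute_pipeline are placeholders, and resource binding is elided because it normally flows through generated shader bindings into the BindResource overloads above.

  std::shared_ptr<CommandBuffer> cmd_buffer = context->CreateCommandBuffer();
  std::shared_ptr<ComputePass> pass = cmd_buffer->CreateComputePass();

  pass->SetCommandLabel("Example dispatch");
  pass->SetPipeline(compute_pipeline);
  // ... bind buffers/textures here ...

  // One thread per element of a 4096-element linear grid.
  if (!pass->Compute(ISize(4096, 1)).ok()) {
    // Handle the error.
  }
  // Ends encoding on the underlying MTLComputeCommandEncoder.
  pass->EncodeCommands();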