A Brief Analysis of the PyTorch Dispatcher Source Code

This article first introduces how the PyTorch dispatcher works and then walks through the implementation source of v2.7.1; I hope it proves helpful. Before reading on, I recommend reading Let’s talk about the PyTorch dispatcher by PyTorch core developer Edward Z. Yang to get a feel for the dispatcher's core ideas; the principles part of this article is also largely based on that post.

1. Problem Background and Implementation Principles of the Dispatcher

PyTorch currently supports thousands of operators on multiple device backends such as CPU/CUDA/FPGA, and as time goes on it must keep supporting new devices and new operators. Its code architecture therefore has to be extremely extensible, and this is currently achieved through the dispatcher. The dispatcher plays a central role in the PyTorch code base; in Edward Z. Yang's words, it is a really important abstraction for how we structure our code internally.

The extensibility problem the dispatcher must solve, supporting new devices and new operators at the same time, is in essence the expression problem, which currently has no good general solution. In compiler terms, for example, the problem is how to efficiently add both new rows and new columns to the table shown below.

2025-07-06_13-53-17_table.png

Object-oriented programming handles adding rows well, while functional programming handles adding columns well (object-oriented programming can emulate the latter with the visitor design pattern), but there is still no good way to support adding rows and columns at the same time, as the sketch below illustrates.
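To make the trade-off concrete, here is a minimal C++ sketch (a generic illustration with my own made-up types Expr/Lit/Neg, not PyTorch code): adding a new row (node type) is a self-contained subclass, while adding a new column (operation) forces an edit to every existing class.

#include <memory>
#include <string>
#include <utility>

// Toy sketch of the expression problem, not PyTorch code.
struct Expr {
  virtual ~Expr() = default;
  virtual int eval() const = 0;           // column 1: evaluate
  virtual std::string print() const = 0;  // column 2: adding this later touches every subclass
};

// A row: one self-contained class.
struct Lit : Expr {
  int v;
  explicit Lit(int v) : v(v) {}
  int eval() const override { return v; }
  std::string print() const override { return std::to_string(v); }
};

// A new row costs no changes elsewhere; a new virtual method (column) would.
struct Neg : Expr {
  std::unique_ptr<Expr> e;
  explicit Neg(std::unique_ptr<Expr> e) : e(std::move(e)) {}
  int eval() const override { return -e->eval(); }
  std::string print() const override { return "-" + e->print(); }
};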

In PyTorch, the row index of this table is the operator and the column index is the DispatchKey (which describes device backends such as CPU/CUDA as well as PyTorch features such as Autograd/Tracing); each table cell is a KernelFunction (the implementation an operator provides for a given DispatchKey), as shown below:

2025-07-06_14-50-49_slide-14.png

The dispatcher's solution to this problem is as follows:

  1. First, fully decouple the rows (operators) from the columns (DispatchKeys); the basic unit of management is the table cell (KernelFunction).
  2. Then, manage all the KernelFunctions in the table through a central hub, and at runtime select one KernelFunction to execute by weighing factors such as the operator arguments and thread-local state (all described by DispatchKeys), as shown below:

    2025-07-06_14-53-44_slide-06.png

    "Selecting one KernelFunction after weighing several factors" in essence also solves another problem known as multiple dispatch. C++ virtual functions only support single dispatch, so solving this in C++ takes some tricks; see More Effective C++ Item 31: Making functions virtual with respect to more than one object. Besides dispatching on DispatchKey as described above, a KernelFunction may internally dispatch further on the tensor dtype; that is implemented as a plain switch statement and will not be discussed in this article.

The crux of the dispatcher implementation is therefore how to represent the KernelFunction table described above, and how to represent and combine the factors (DispatchKeys) used to pick a KernelFunction from it. The next section analyzes the concrete implementation from the source code.
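Before diving into the source, the following deliberately simplified sketch (a toy model with my own names, not PyTorch code) shows the shape of the solution: a central registry stores per-cell kernels with rows and columns fully decoupled, and call() picks the highest-priority kernel present in the runtime key set.

#include <cstdint>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Toy model of the dispatcher, not PyTorch code. Keys are small integers
// (< 64) where a larger key means higher priority.
using Kernel = std::function<void()>;

struct ToyDispatcher {
  std::map<std::string, std::map<int, Kernel>> table;  // rows: ops, columns: keys

  void register_kernel(const std::string& op, int key, Kernel k) {
    table[op][key] = std::move(k);  // rows and columns are fully decoupled
  }

  void call(const std::string& op, uint64_t key_set) {
    auto& row = table.at(op);
    // Pick the highest-priority key present in both the set and the row.
    for (auto it = row.rbegin(); it != row.rend(); ++it) {
      if (key_set & (1ull << it->first)) return it->second();
    }
    throw std::runtime_error("no kernel for " + op);
  }
};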

2. Source Code Analysis

2.1. DispatchKey and DispatchKeySet

A given operator must be adapted not only to device backends such as CPU/CUDA/FPGA but also to PyTorch features such as Autograd/Tracing. In addition, features such as Autograd must themselves be specialized for the customizable backends such as CPU/CUDA; for example, the backend-specialized variants of Autograd include AutogradCPU/AutogradCUDA/AutogradHIP.

DispatchKey expresses these kernel-selection factors, while DispatchKeySet combines multiple DispatchKeys and computes the KernelFunction table column index from the combined result. Both are examined in detail below.

2.1.1. DispatchKey

The dispatcher uses DispatchKey to express these KernelFunction selection factors. It is implemented as a uint16_t enum class in which every selection factor corresponds to one enumerator:

enum class DispatchKey : uint16_t {
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  Undefined = 0,
  CatchAll = Undefined,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
  Dense,
  FPGA,
  ...
  EndOfFunctionalityKeys, // End of functionality keys.

  // ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n,
#define DEFINE_PER_BACKEND_KEYS(fullname, prefix)      \
    StartOf##fullname##Backends,                         \
        C10_FORALL_BACKEND_COMPONENTS(                   \
            DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \
            EndOf##fullname##Backends = prefix##Meta,
    C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS)
#undef DEFINE_PER_BACKEND_KEYS
#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND

  EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends,

  // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // Alias dispatch keys are synthetic dispatch keys which map to multiple runtime dispatch keys.
  Autograd,
  CompositeImplicitAutograd,
  ...
  StartOfAliasKeys = Autograd,
  EndOfAliasKeys = CompositeExplicitAutogradNonFunctional,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // The aliases exist for backwards compatibility reasons, they shouldn't be used
  CPUTensorId = CPU,
  CUDATensorId = CUDA,
  DefaultBackend = CompositeExplicitAutograd,
  PrivateUse1_PreAutograd = AutogradPrivateUse1,
  PrivateUse2_PreAutograd = AutogradPrivateUse2,
  PrivateUse3_PreAutograd = AutogradPrivateUse3,
  Autocast = AutocastCUDA,
};

As can be seen, DispatchKeys fall into the Functionality / Per-Backend / Alias / BC (backwards compatibility) categories. For maintainability, the code uses the C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS) macro to generate the enumerators for the backend-specialized variants of PyTorch features such as Autograd.
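For intuition, the expansion looks roughly like the hand-written fragment below (abridged; the exact backend list and ordering come from C10_FORALL_BACKEND_COMPONENTS and vary by version). For the Dense functionality the prefix is empty, so its per-backend keys are the familiar CPU/CUDA/...; for the Autograd functionality the prefix is Autograd:

#include <cstdint>

// Rough shape of the macro expansion, abridged for illustration only.
enum class DispatchKeyExpansionSketch : uint16_t {
  StartOfDenseBackends,
  CPU,   // prefix "" + backend CPU
  CUDA,  // prefix "" + backend CUDA
  Meta,  // last backend component
  EndOfDenseBackends = Meta,

  StartOfAutogradFunctionalityBackends,
  AutogradCPU,   // prefix "Autograd" + backend CPU
  AutogradCUDA,  // prefix "Autograd" + backend CUDA
  AutogradMeta,
  EndOfAutogradFunctionalityBackends = AutogradMeta,
};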

2.1.2. DispatchKeySet

DispatchKeySet combines multiple DispatchKeys (this is what implements multiple dispatch) and computes the KernelFunction table column index from the combined result. DispatchKey by itself cannot do this well, for the following reasons:

  1. Alias and per-backend-functionality DispatchKey enumerators do not map to a column index of the KernelFunction table.
  2. Per-backend DispatchKey enumerators carry redundant information (e.g. AutogradCPU and AutogradCUDA both correspond to Autograd), so their functionality and backend parts must be separated before such enumerators can be merged.

A DispatchKeySet extracts the functionality and backend information from a DispatchKey, so multiple DispatchKeySets can be merged directly with bitwise operations. DispatchKeySet stores the extracted functionality and backend information in a single uint64_t data member, as shown below:

dispatch_key_set.png

When the dispatcher selects a KernelFunction, functionalities and backends are ordered by priority; DispatchKeySet stores higher-priority functionalities and backends in higher bits, as shown above.

The relationship between DispatchKey, DispatchKeySet bits, and KernelFunction table column indices is as follows:

DispatchKey type | Is it a DispatchKeySet bit? | Does it map to one KernelFunction table column?
Customizable backends (all DispatchKeys in BackendComponent) | Yes | No
A per-backend functionality DispatchKey itself | Yes | No (it covers num_backends columns)
Backend-specialized instances of a per-backend functionality DispatchKey | No (but convertible to a DispatchKeySet with two bits set) | Yes
Non-customizable backends | Yes | Yes
Non-customizable functionalities | Yes | Yes

DispatchKeySet is implemented as follows:

// A representation of a set of DispatchKeys. A DispatchKeySet contains both "functionality" bits and "backend bits", and every tensor holds its own DispatchKeySet.
class DispatchKeySet final {
 public:
  // Extract the functionality and backend information from a DispatchKey
  constexpr explicit DispatchKeySet(DispatchKey k) { ... }

  // Merge DispatchKeySets; this is what implements multiple dispatch
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(repr_ | other.repr_);
  }

  // Compute the KernelFunction table column index for this DispatchKeySet
  int getDispatchTableIndexForDispatchKeySet() const {
    auto functionality_idx =
        DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
    auto offset_and_mask = offsetsAndMasks()[functionality_idx];
    auto backend_idx =
        DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
    return offset_and_mask.offset + backend_idx;
  }

 private:
  constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
  uint64_t repr_ = 0;
};

Here, the DispatchKeySet(DispatchKey k) constructor extracts the functionality and backend information from a DispatchKey, while overloaded operators such as operator|(DispatchKeySet other) merge multiple DispatchKeySets. The ultimate purpose of a DispatchKeySet is to compute the KernelFunction table column index, which getDispatchTableIndexForDispatchKeySet() implements. In my minimally configured build of PyTorch v2.7.1, the KernelFunction table has 132 columns.
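As a small usage sketch (the types and member functions are the real c10 ones shown above; the surrounding code is illustrative):

#include <c10/core/DispatchKeySet.h>

using c10::DispatchKey;
using c10::DispatchKeySet;

void dispatch_key_set_demo() {
  // AutogradCUDA decomposes into an Autograd functionality bit plus a CUDA backend bit.
  DispatchKeySet autograd_cuda(DispatchKey::AutogradCUDA);
  DispatchKeySet dense_cuda(DispatchKey::CUDA);

  // Merging is a single bitwise OR over the uint64_t representation.
  DispatchKeySet merged = autograd_cuda | dense_cuda;

  // The highest-priority functionality wins: this index selects the
  // AutogradCUDA column of the KernelFunction table rather than the CUDA one.
  int col = merged.getDispatchTableIndexForDispatchKeySet();
  (void)col;
}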

The offsetsAndMasks() array that getDispatchTableIndexForDispatchKeySet() relies on is computed as follows:

std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
  std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      offsets_and_masks;
  // manually set the first entry, which corresponds to Undefined.
  offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
  // loop through every functionality key (aside from Undefined).
  for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
    // functionality_idx should be Dense -> 1, ...
    auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
    auto k = static_cast<DispatchKey>(functionality_idx);

    // If the previous functionality was not per-backend, then we can just
    // increment the previous offset. Otherwise, the next offset =
    // previous_offset + num_backends.
    auto next_offset = prev_offset_and_mask.offset +
        (prev_offset_and_mask.mask == 0 ? 1 : num_backends);
    // the mask is used in the runtime index calculation to find the offset of
    // the backend. For non-per-backend functionalities, this offset should
    // always be 0. Otherwise, we need to get the index of the backend (which we
    // can do using a backend mask).
    auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
    offsets_and_masks[functionality_idx] =
        FunctionalityOffsetAndMask(next_offset, next_mask);
  }
  return offsets_and_masks;
}

Combining this with the getDispatchTableIndexForDispatchKeySet() code above, we can see that an ordinary functionality occupies one column of the KernelFunction table, while a per-backend functionality occupies num_backends columns.
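A hedged worked example of the index arithmetic (the concrete offsets depend on the enum order in a given build): Undefined occupies column 0; if Dense is the first functionality after it, Dense gets offset 1 together with full_backend_mask, and the next functionality then starts at 1 + num_backends.

// Illustrative layout (offsets depend on the enum order of your build):
//
//   functionality        offset              mask                columns
//   Undefined            0                   0                   1
//   Dense                1                   full_backend_mask   num_backends
//   next functionality   1 + num_backends    ...                 ...
//
// For a DispatchKeySet holding the Dense functionality bit and the CUDA backend bit:
//   functionality_idx -> Dense, so offset = 1 and mask = full_backend_mask;
//   backend_idx       -> position of the CUDA bit among the backend bits;
//   column            =  1 + backend_idx.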

2.2. The KernelFunction Table and Operator Registration

2.2.1. Data Structures Behind the KernelFunction Table

The KernelFunction table is a two-dimensional table whose rows are indexed by operator and whose columns are indexed by DispatchKey. PyTorch stores it in Dispatcher::operators_ and Dispatcher::operatorLookupTable_:

class TORCH_API Dispatcher final {
 private:
  std::list<OperatorDef> operators_;
  LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>>
      operatorLookupTable_;
};

Here, each OperatorDef in operators_ stores all the KernelFunctions of one operator, and operatorLookupTable_ maps an operator name to an OperatorHandle; an OperatorHandle is essentially a handle pointing to an OperatorDef and is defined as follows:

class TORCH_API OperatorHandle {
 private:
  Dispatcher::OperatorDef* operatorDef_;
  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

OperatorDef is defined as follows:

struct OperatorDef final {
  explicit OperatorDef(OperatorName&& op_name) : op(std::move(op_name)) {}
  impl::OperatorEntry op;
  size_t def_count = 0;
  size_t def_and_impl_count = 0;
};

As can be seen, OperatorDef's main job is to provide reference counting for OperatorEntry; OperatorEntry is the real entity and is defined as follows:

class TORCH_API OperatorEntry final {
 private:
  OperatorName name_;
  std::optional<AnnotatedSchema> schema_;
  std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;

  DispatchKeyExtractor dispatchKeyExtractor_;

  ska::flat_hash_map<
      DispatchKey,
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
      // On mobile, we needn't worry about Jupyter notebooks.
      std::array<AnnotatedKernel, 1>
#else
      std::list<AnnotatedKernel>
#endif
      >
      kernels_;
};

Here, the dispatchTable_ array stores all the KernelFunctions of one operator (the array index for a given DispatchKey can be computed with DispatchKeySet::getDispatchTableIndexForDispatchKeySet()); dispatchKeyExtractor_ is the operator-specific DispatchKey extractor (each operator has its own extraction rules); and the kernels_ map records every KernelFunction that has been registered for each DispatchKey of this operator (multiple KernelFunctions may be registered for one DispatchKey, but only the most recently registered one is effective), as the toy model below illustrates.
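The interplay between kernels_ and dispatchTable_ can be pictured with the following toy model (my own sketch, not PyTorch code): each column keeps a stack of registrations, and the flat table always caches the newest one, so deregistering falls back to the previous kernel.

#include <array>
#include <cstddef>
#include <functional>
#include <list>
#include <unordered_map>
#include <utility>

// Toy model, not PyTorch code.
using Kernel = std::function<void()>;
constexpr std::size_t kNumCols = 132;  // the column count observed in the build above

std::unordered_map<int, std::list<Kernel>> kernels;  // newest registration at the front
std::array<Kernel, kNumCols> table;                  // what dispatch actually reads

void register_kernel(int col, Kernel k) {
  kernels[col].push_front(std::move(k));
  table[col] = kernels[col].front();  // invariant: table points at the newest kernel
}

void deregister_newest(int col) {
  kernels[col].pop_front();
  table[col] = kernels[col].empty() ? Kernel{} : kernels[col].front();
}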

2.2.2. Operator Registration

PyTorch provides two ways to register operators, the TORCH_LIBRARY API and RegisterOperators, with the TORCH_LIBRARY API being officially recommended; both ultimately work by calling Dispatcher::registerDef() and Dispatcher::registerImpl(). Dispatcher::registerDef() registers an operator definition and is implemented as follows:

RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
  ...
  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);
  ...
}

OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
  ...
  operators_.emplace_back(OperatorName(op_name));
  OperatorHandle handle(--operators_.end());
  operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
    operatorLookupTable.emplace(op_name, handle);
  });
  return handle;
}

As can be seen, its main job is to append a new OperatorDef to the Dispatcher::operators_ list, and it also updates the operator lookup table operatorLookupTable_. Dispatcher::registerImpl() registers a KernelFunction implementation for an operator and is implemented as follows:

RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  std::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  std::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  ...
  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );
  ...
}

As can be seen, it mainly delegates to OperatorEntry::registerKernel() to register the KernelFunction, which is implemented as follows:

OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
  const c10::Dispatcher& dispatcher,
  std::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  std::optional<CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  ...
  auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  k[0].kernel = std::move(kernel);
  k[0].inferred_function_schema = std::move(inferred_function_schema);
  k[0].debug = std::move(debug);
#else
  k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
#endif
  AnnotatedKernelContainerIterator inserted = k.begin();
  // update the dispatch table, i.e. re-establish the invariant
  // that the dispatch table points to the newest kernel
  if (dispatch_key.has_value()) {
    updateDispatchTable_(dispatcher, *dispatch_key);
  } else {
    updateDispatchTableFull_(dispatcher);
  }
  return inserted;
}

As shown, it first inserts an AnnotatedKernel into the DispatchKey's list in the OperatorEntry::kernels_ map, and then calls OperatorEntry::updateDispatchTable_() to update OperatorEntry::dispatchTable_:

void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
  ...
  for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
    updateDispatchTableEntry_(dispatcher, k);
  }
  ...
}

void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
  if (C10_UNLIKELY(dispatch_ix == -1)) {
    return;
  }
  dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
  dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

In summary, registering an operator definition mainly appends an OperatorDef to the Dispatcher::operators_ list, while registering a KernelFunction implementation inserts an AnnotatedKernel into the corresponding value list of that OperatorDef's OperatorEntry::kernels_ map and then updates OperatorEntry::dispatchTable_.
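For completeness, this is roughly what the recommended registration path looks like from user code (the namespace myops and the kernel myadd_cpu are hypothetical): m.def() ends up in Dispatcher::registerDef(), and m.impl() in Dispatcher::registerImpl().

#include <torch/library.h>

// Hypothetical CPU kernel, for illustration only.
at::Tensor myadd_cpu(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

// Registers the schema (a row of the KernelFunction table) -> Dispatcher::registerDef().
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor a, Tensor b) -> Tensor");
}

// Registers a kernel for the CPU column -> Dispatcher::registerImpl().
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}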

2.3. The Runtime Path

When a PyTorch operator executes, it first looks up its OperatorHandle in Dispatcher::operatorLookupTable_ by metadata such as the operator name and schema (the row index of the KernelFunction table), then computes a DispatchKeySet from the various DispatchKeys to obtain the column index of the KernelFunction table, and finally uses that column index to find the KernelFunction in the OperatorEntry behind the OperatorHandle and invokes it. Note that the logic inside a KernelFunction may repeat this flow to execute other KernelFunctions of the same operator. For example, an operator may first run its autograd-related logic and then the KernelFunction implementing the actual math; these correspond to several different KernelFunctions, but they all share the same execution flow.

The rest of this section analyzes the dispatcher's runtime path using Tensor.add() as an example; the overall call stack looks like this:

2025-07-06_21-46-21_screenshot.png

When tensor1 + tensor2 is executed from the PyTorch Python frontend, it reaches the C++-side THPVariable_add() through pybind:

static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ...
  auto dispatch_add = [](const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) -> at::Tensor {
    pybind11::gil_scoped_release no_gil;
    return self.add(other, alpha);
  };
  ...
}

This then executes at::Tensor::add():

inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const {
    return at::_ops::add_Tensor::call(const_cast<Tensor&>(*this), other, alpha);
}

at::Tensor add_Tensor::call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
    static auto op = create_add_Tensor_typed_handle();
    return op.call(self, other, alpha);
}

at::Tensor::add() passes straight through to add_Tensor::call(), which first looks up the OperatorHandle of the add operator and then invokes the call() method on the result.

2.3.1. Looking up the OperatorHandle

create_add_Tensor_typed_handle() is implemented as follows:

static C10_NOINLINE c10::TypedOperatorHandle<add_Tensor::schema> create_add_Tensor_typed_handle() {
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow(add_Tensor::name, add_Tensor::overload_name)
      .typed<add_Tensor::schema>();
}

As can be seen, it essentially delegates to Dispatcher::findSchemaOrThrow():

OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overload_name) {
  auto it = findSchema({name, overload_name});
  ...
  return it.value();
}

std::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
  auto it = findOp(overload_name);
  ...
}

As shown, findSchemaOrThrow() ultimately reaches findOp():

std::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_name) {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::optional<OperatorHandle> {
    auto found = operatorLookupTable.find(overload_name);
    if (found == operatorLookupTable.end()) {
      return std::nullopt;
    }
    return found->second;
  });
}

findOp() essentially looks up the aforementioned Dispatcher::operatorLookupTable_ map by OperatorName. For the add operator, the final lookup result is a TypedOperatorHandle, a subclass of OperatorHandle.

2.3.2. Invoking the OperatorHandle

After add_Tensor::call() obtains the OperatorHandle, it executes the TypedOperatorHandle::call() method:

template <class Return, class... Args>
class TypedOperatorHandle<Return(Args...)> final : public OperatorHandle {
 public:
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(
        *this, std::forward<Args>(args)...);
  }
};

As can be seen, it delegates to Dispatcher::call():

template <class Return, class... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(
    const TypedOperatorHandle<Return(Args...)>& op,
    Args... args) const {
  auto dispatchKeySet =
      op.operatorDef_->op.dispatchKeyExtractor()
          .template getDispatchKeySetUnboxed<Args...>(args...);
  ...
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
  ...
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}

call() first computes the DispatchKeySet, then derives the KernelFunction table column index from it, then uses that column index to look up the KernelFunction in the OperatorEntry, and finally invokes it.
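The same machinery can also be driven by hand; the following sketch mirrors what the generated add_Tensor::call() above does (the wrapper function name is mine):

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_add_via_dispatcher(const at::Tensor& a, const at::Tensor& b) {
  // Row lookup: find the OperatorHandle by name, then fix its C++ signature.
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&, const at::Scalar&)>();
  // Column lookup and kernel invocation happen inside Dispatcher::call().
  return op.call(a, b, /*alpha=*/1);
}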

Besides Dispatcher::call(), the Dispatcher also offers redispatch(), which is essentially the same as call() except that it receives the DispatchKeySet as a parameter instead of computing it from the arguments:

template <class Return, class... Args>
inline Return Dispatcher::redispatch(
    const TypedOperatorHandle<Return(Args...)>& op,
    DispatchKeySet currentDispatchKeySet,
    Args... args) const {
  ...
  const KernelFunction& kernel =
      op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(
      op, currentDispatchKeySet, std::forward<Args>(args)...);
}

2.3.3. Computing the DispatchKeySet

Dispatcher::call() uses OperatorEntry::dispatchKeyExtractor_ to compute the DispatchKeySet from the input tensors, thread-local state, and the operator-specific logic inside the DispatchKeyExtractor (this is what implements multiple dispatch):

template <class... Args>
DispatchKeySet DispatchKeyExtractor::getDispatchKeySetUnboxed(
    const Args&... args) const {
  auto ks = detail::multi_dispatch_key_set(args...);
  // Keys that are fallthrough should be skipped
  if (requiresBitsetPerBackend_) {
    c10::impl::LocalDispatchKeySet tls =
        c10::impl::tls_local_dispatch_key_set();
    auto backend_idx =
        ((ks | tls.included_) - tls.excluded_).getBackendIndex();
    return impl::computeDispatchKeySet(
        ks, nonFallthroughKeysPerBackend_[backend_idx]);
  } else {
    return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
  }
}

As can be seen, getDispatchKeySetUnboxed() first merges the DispatchKeys of the input arguments via detail::multi_dispatch_key_set(args...), then merges in the LocalDispatchKeySet, and finally merges the DispatchKeyExtractor's operator-specific nonFallthroughKeys_. The logic of multi_dispatch_key_set() is as follows:

template <typename... Args>
DispatchKeySet multi_dispatch_key_set(const Args&... args) {
  return MultiDispatchKeySet().apply(args...).ts;
}

struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
  DispatchKeySet ts;
  void operator()(const at::Tensor& x) {
    ts = ts | x.key_set();
  }
};

struct C10_API TensorImpl : public c10::intrusive_ptr_target {
 private:
  DispatchKeySet key_set_;
};

As can be seen, multi_dispatch_key_set() essentially bitwise-ORs together the key_set_ of every input tensor. The logic of computeDispatchKeySet() is as follows:

inline DispatchKeySet computeDispatchKeySet(
    DispatchKeySet ks,
    DispatchKeySet key_mask) {
  c10::impl::LocalDispatchKeySet local =
      c10::impl::tls_local_dispatch_key_set();
  return (((ks | local.included_) - local.excluded_) & key_mask);
}
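These TLS sets are what dispatch-controlling guards manipulate. As a sketch (ExcludeDispatchKeyGuard and autograd_dispatch_keyset are real c10 names, but treat the snippet as illustrative): excluding the autograd keys makes computeDispatchKeySet() mask them out, so dispatch skips the autograd kernels.

#include <ATen/ATen.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

at::Tensor add_without_autograd(const at::Tensor& a, const at::Tensor& b) {
  // Adds the autograd functionality keys to tls.excluded_, so
  // computeDispatchKeySet() removes them and dispatch bypasses autograd.
  c10::impl::ExcludeDispatchKeyGuard guard(c10::autograd_dispatch_keyset);
  return a + b;
}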

Besides DispatchKeyExtractor::getDispatchKeySetUnboxed(), there is also getDispatchKeySetBoxed(), which is essentially the same except that it reads the input tensors from the JIT stack:

DispatchKeySet DispatchKeyExtractor::getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
  DispatchKeySet ks;
  dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t
                                                         reverse_arg_index) {
    const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
    ...
}

2.3.4. Looking up the KernelFunction

With the DispatchKeySet in hand, OperatorEntry::lookup() first computes the corresponding KernelFunction table column index via DispatchKeySet::getDispatchTableIndexForDispatchKeySet() and then indexes directly into the dispatchTable_ array to obtain the KernelFunction:

const KernelFunction& OperatorEntry::lookup(DispatchKeySet ks) const {
  const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
  if (C10_UNLIKELY(idx == -1)) {
    reportError(ks.highestPriorityTypeId());
  }
  const auto& kernel = dispatchTable_[idx];
  ...
  return kernel;
}

2.3.5. Invoking the KernelFunction

With the KernelFunction in hand, the dispatcher executes KernelFunction::call(), whose logic is as follows:

template <class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(
  const OperatorHandle& opHandle,
  DispatchKeySet dispatchKeySet,
  Args... args) const {
  ...
  auto* functor = boxed_kernel_func_.getFunctor();
  return callUnboxedKernelFunction<Return, Args...>(
      unboxed_kernel_func_,
      functor,
      dispatchKeySet,
      std::forward<Args>(args)...);
  ...
}

As can be seen, it delegates to callUnboxedKernelFunction():

template <class Return, class... Args>
inline Return callUnboxedKernelFunction(
    void* unboxed_kernel_func,
    OperatorKernel* functor,
    DispatchKeySet dispatchKeySet,
    Args&&... args) {
  using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...);
  ActualSignature* func =
      reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
  return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
}

template <class KernelFunctor, class ReturnType, class... ParameterTypes>
struct wrap_kernel_functor_unboxed_<
    KernelFunctor,
    ReturnType(DispatchKeySet, ParameterTypes...)>
    final {
  static ReturnType call(
      OperatorKernel* functor,
      DispatchKeySet dispatchKeySet,
      ParameterTypes... args) {
    KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
    return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
  }
};

template <class FuncPtr, class ReturnType, class... Parameters>
class WrapFunctionIntoFunctor_<
    FuncPtr,
    ReturnType,
    guts::typelist::typelist<Parameters...>>
    final : public c10::OperatorKernel {
 public:
  C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
    return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
  }
};

As can be seen, callUnboxedKernelFunction() goes through several layers of perfect forwarding before reaching the KernelFunction's final target. For tensor1 + tensor2, the first KernelFunction executed by the add operator ultimately runs torch::autograd::VariableType::add_Tensor() (defined in an anonymous namespace in the generated code).

2.3.6. redispatch

torch::autograd::VariableType::add_Tensor() is implemented as follows:

at::Tensor torch::autograd::VariableType::add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  ...
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::redispatch::add(ks & c10::after_autograd_keyset, self_, other_, alpha);
  })();
  ...
}

As shown, it first uses at::AutoDispatchBelowADInplaceOrView guard; to mask out the autograd-related DispatchKeys, which lets the subsequent dispatch flow resolve to one of the add operator's other KernelFunctions. It then executes at::redispatch::add():

inline at::Tensor at::redispatch::add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
    return at::_ops::add_Tensor::redispatch(dispatchKeySet, self, other, alpha);
}

at::Tensor add_Tensor::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
    static auto op = create_add_Tensor_typed_handle();
    return op.redispatch(dispatchKeySet, self, other, alpha);
}

C10_ALWAYS_INLINE Return
redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
  return c10::Dispatcher::singleton().redispatch<Return, Args...>(
      *this, currentDispatchKeySet, std::forward<Args>(args)...);
}

Eventually this reaches the previously introduced Dispatcher::redispatch(); the logic from there on closely mirrors the flow already described, so the source analysis stops here.