Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions runtime/kernel/operator_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ bool registry_has_op_function(

Result<OpFunction> get_op_function_from_registry(
const char* name,
Span<const TensorMeta> meta_list) {
Span<const TensorMeta> meta_list,
Span<const Kernel> kernel_list) {
std::array<char, internal::kKernelKeyBufSize> key_string;
Error err = internal::make_kernel_key_string(
meta_list, key_string.data(), key_string.size());
Expand All @@ -260,24 +261,31 @@ Result<OpFunction> get_op_function_from_registry(
KernelKey kernel_key = KernelKey(key_string.data());

int32_t fallback_idx = -1;
for (size_t idx = 0; idx < num_registered_kernels; idx++) {
if (strcmp(registered_kernels[idx].name_, name) == 0) {
if (registered_kernels[idx].kernel_key_ == kernel_key) {
return registered_kernels[idx].op_;
for (size_t idx = 0; idx < kernel_list.size(); idx++) {
if (strcmp(kernel_list[idx].name_, name) == 0) {
if (kernel_list[idx].kernel_key_ == kernel_key) {
return kernel_list[idx].op_;
}
if (registered_kernels[idx].kernel_key_.is_fallback()) {
if (kernel_list[idx].kernel_key_.is_fallback()) {
fallback_idx = idx;
}
}
}
if (fallback_idx != -1) {
return registered_kernels[fallback_idx].op_;
return kernel_list[fallback_idx].op_;
}
ET_LOG(Error, "kernel '%s' not found.", name);
ET_LOG_TENSOR_META(meta_list);
return Error::OperatorMissing;
}

Result<OpFunction> get_op_function_from_registry(
const char* name,
Span<const TensorMeta> meta_list) {
return get_op_function_from_registry(
name, meta_list, get_registered_kernels());
}

Span<const Kernel> get_registered_kernels() {
return {registered_kernels, num_registered_kernels};
}
Expand Down
9 changes: 9 additions & 0 deletions runtime/kernel/operator_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,15 @@ ::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
const char* name,
Span<const TensorMeta> meta_list = {});

/**
* Returns the operator with a given name and TensorMeta list from the provided
* kernel list instead of the global registry.
*/
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
const char* name,
Span<const TensorMeta> meta_list,
Span<const Kernel> kernel_list);

/**
* Returns all registered kernels.
*/
Expand Down
53 changes: 53 additions & 0 deletions runtime/kernel/test/operator_registry_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -387,6 +387,59 @@
ASSERT_EQ(val_2, 50);
}

TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
std::array<char, kKernelKeyBufSize> buf;
Error err = make_kernel_key(
{{ScalarType::Long, {0, 1, 2, 3}}},
buf.data(),
buf.size());
ASSERT_EQ(err, Error::Ok);
KernelKey long_key = KernelKey(buf.data());

Kernel kernels[] = {
Kernel(
"test::provided_kernel_list",
KernelKey{},
[](KernelRuntimeContext& context, Span<EValue*> stack) {
(void)context;
*(stack[0]) = Scalar(50);
}),
Kernel(
"test::provided_kernel_list",
long_key,
[](KernelRuntimeContext& context, Span<EValue*> stack) {
(void)context;
*(stack[0]) = Scalar(100);
}),
};
Span<const Kernel> kernels_span(kernels);

Tensor::DimOrderType dims[] = {0, 1, 2, 3};
auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
TensorMeta long_meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
Span<const TensorMeta> long_kernel_key(long_meta);

auto run_kernel = [](OpFunction func) {
EValue value = Scalar(0);
EValue* stack[] = {&value};
KernelRuntimeContext context{};
func(context, Span<EValue*>(stack));
return value.toScalar().to<int64_t>();
};

Result<OpFunction> specialized_func = get_op_function_from_registry(
"test::provided_kernel_list", long_kernel_key, kernels_span);
ASSERT_EQ(specialized_func.error(), Error::Ok);
EXPECT_EQ(run_kernel(*specialized_func), 100);

TensorMeta float_meta[] = {TensorMeta(ScalarType::Float, dim_order_type)};
Span<const TensorMeta> float_kernel_key(float_meta);
Result<OpFunction> fallback_func = get_op_function_from_registry(
"test::provided_kernel_list", float_kernel_key, kernels_span);
ASSERT_EQ(fallback_func.error(), Error::Ok);
EXPECT_EQ(run_kernel(*fallback_func), 50);
}

TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
Error err = make_kernel_key(
Expand Down
Loading