diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index 1bfe3dc45ab..a5803cd50a6 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -249,7 +249,8 @@ bool registry_has_op_function( Result get_op_function_from_registry( const char* name, - Span meta_list) { + Span meta_list, + Span kernel_list) { std::array key_string; Error err = internal::make_kernel_key_string( meta_list, key_string.data(), key_string.size()); @@ -260,24 +261,31 @@ Result get_op_function_from_registry( KernelKey kernel_key = KernelKey(key_string.data()); int32_t fallback_idx = -1; - for (size_t idx = 0; idx < num_registered_kernels; idx++) { - if (strcmp(registered_kernels[idx].name_, name) == 0) { - if (registered_kernels[idx].kernel_key_ == kernel_key) { - return registered_kernels[idx].op_; + for (size_t idx = 0; idx < kernel_list.size(); idx++) { + if (strcmp(kernel_list[idx].name_, name) == 0) { + if (kernel_list[idx].kernel_key_ == kernel_key) { + return kernel_list[idx].op_; } - if (registered_kernels[idx].kernel_key_.is_fallback()) { + if (kernel_list[idx].kernel_key_.is_fallback()) { fallback_idx = idx; } } } if (fallback_idx != -1) { - return registered_kernels[fallback_idx].op_; + return kernel_list[fallback_idx].op_; } ET_LOG(Error, "kernel '%s' not found.", name); ET_LOG_TENSOR_META(meta_list); return Error::OperatorMissing; } +Result get_op_function_from_registry( + const char* name, + Span meta_list) { + return get_op_function_from_registry( + name, meta_list, get_registered_kernels()); +} + Span get_registered_kernels() { return {registered_kernels, num_registered_kernels}; } diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index 1369ead1024..4f69bb75a07 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -233,6 +233,15 @@ ::executorch::runtime::Result get_op_function_from_registry( const char* name, Span meta_list = {}); +/** + * Returns the operator with a given name and TensorMeta list from the provided + * kernel list instead of the global registry. + */ +::executorch::runtime::Result get_op_function_from_registry( + const char* name, + Span meta_list, + Span kernel_list); + /** * Returns all registered kernels. */ diff --git a/runtime/kernel/test/operator_registry_test.cpp b/runtime/kernel/test/operator_registry_test.cpp index 5bc411b43ee..774b732cd2d 100644 --- a/runtime/kernel/test/operator_registry_test.cpp +++ b/runtime/kernel/test/operator_registry_test.cpp @@ -387,6 +387,57 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) { ASSERT_EQ(val_2, 50); } +TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) { + std::array buf; + Error err = make_kernel_key( + {{ScalarType::Long, {0, 1, 2, 3}}}, buf.data(), buf.size()); + ASSERT_EQ(err, Error::Ok); + KernelKey long_key = KernelKey(buf.data()); + + Kernel kernels[] = { + Kernel( + "test::provided_kernel_list", + KernelKey{}, + [](KernelRuntimeContext& context, Span stack) { + (void)context; + *(stack[0]) = Scalar(50); + }), + Kernel( + "test::provided_kernel_list", + long_key, + [](KernelRuntimeContext& context, Span stack) { + (void)context; + *(stack[0]) = Scalar(100); + }), + }; + Span kernels_span(kernels); + + Tensor::DimOrderType dims[] = {0, 1, 2, 3}; + auto dim_order_type = Span(dims, 4); + TensorMeta long_meta[] = {TensorMeta(ScalarType::Long, dim_order_type)}; + Span long_kernel_key(long_meta); + + auto run_kernel = [](OpFunction func) { + EValue value = Scalar(0); + EValue* stack[] = {&value}; + KernelRuntimeContext context{}; + func(context, Span(stack)); + return value.toScalar().to(); + }; + + Result specialized_func = get_op_function_from_registry( + "test::provided_kernel_list", long_kernel_key, kernels_span); + ASSERT_EQ(specialized_func.error(), Error::Ok); + EXPECT_EQ(run_kernel(*specialized_func), 100); + + TensorMeta float_meta[] = {TensorMeta(ScalarType::Float, dim_order_type)}; + Span float_kernel_key(float_meta); + Result fallback_func = get_op_function_from_registry( + "test::provided_kernel_list", float_kernel_key, kernels_span); + ASSERT_EQ(fallback_func.error(), Error::Ok); + EXPECT_EQ(run_kernel(*fallback_func), 50); +} + TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) { std::array buf_long_contiguous; Error err = make_kernel_key(