diff --git a/codegen/tools/gen_max_kernel_num.py b/codegen/tools/gen_max_kernel_num.py index 78e1abe2e1e..4297e7e9299 100644 --- a/codegen/tools/gen_max_kernel_num.py +++ b/codegen/tools/gen_max_kernel_num.py @@ -10,7 +10,12 @@ it as a C header. Total = sum of (op, kernel_key) variants across all input YAMLs - + prim ops always registered by kernels/prim_ops/register_prim_ops.cpp. + + prim ops registered by kernels/prim_ops/register_prim_ops.cpp. + +The prim-ops contribution is counted from the .cpp source by default. When +ET_PRIM_OPS_SELECTIVE_BUILD is active (only some prim ops compile in), pass +--selected-prim-ops-header instead so the count reflects what's actually +linked. See runtime/kernel/operator_registry.cpp for how the emitted header is consumed and the full precedence order. Users that register kernels outside @@ -53,6 +58,12 @@ ) PRIM_OPS_KERNEL_RE = re.compile(r"\bKernel\s*\(") +# Matches the `#define INCLUDE_` lines emitted by gen_selected_prim_ops.py +# into selected_prim_ops.h. Each define gates one prim op entry in +# register_prim_ops.cpp under ET_PRIM_OPS_SELECTIVE_BUILD, so the count of +# defines equals the number of prim ops actually compiled in. 
+SELECTED_PRIM_OPS_DEFINE_RE = re.compile(r"^\s*#\s*define\s+INCLUDE_\S+", re.MULTILINE) + def _count_prim_ops(prim_ops_source: Path) -> int: source = prim_ops_source.read_text() @@ -72,6 +83,10 @@ def _count_prim_ops(prim_ops_source: Path) -> int: return count +def _count_selected_prim_ops(selected_prim_ops_header: Path) -> int: + return len(SELECTED_PRIM_OPS_DEFINE_RE.findall(selected_prim_ops_header.read_text())) + + def _count_yaml_kernels(yaml_path: Path) -> Optional[int]: """Returns the kernel count for one YAML, or None if the YAML opts into include_all_operators / include_all_overloads (callers should skip the @@ -116,10 +131,18 @@ def _write_if_different(path: Path, content: str) -> None: def gen_max_kernel_num( oplist_yamls: List[Path], - prim_ops_source: Path, output_path: Path, + prim_ops_source: Optional[Path] = None, + selected_prim_ops_header: Optional[Path] = None, ) -> Optional[int]: - total = 0 + if (prim_ops_source is None) == (selected_prim_ops_header is None): + raise ValueError( + "Pass exactly one of prim_ops_source / selected_prim_ops_header. " + "Use selected_prim_ops_header when ET_PRIM_OPS_SELECTIVE_BUILD is " + "active so the counter matches what actually compiles in." + ) + + yaml_kernels = 0 for yaml_path in oplist_yamls: yaml_count = _count_yaml_kernels(yaml_path) if yaml_count is None: @@ -130,9 +153,26 @@ def gen_max_kernel_num( ) _write_if_different(output_path, OPT_OUT_HEADER) return None - total += yaml_count + yaml_kernels += yaml_count + + # No selective build deps means none of the binary's kernel libraries went + # through this counter — auto-sizing would produce a too-small registry. + # Emit the opt-out header so operator_registry.cpp falls through to the + # compile-time default. 
+ if yaml_kernels == 0: + print( + "gen_max_kernel_num: no kernels enumerated across input YAMLs; " + "emitting opt-out header (registry will use default size).", + file=sys.stderr, + ) + _write_if_different(output_path, OPT_OUT_HEADER) + return None - total += _count_prim_ops(prim_ops_source) + if selected_prim_ops_header is not None: + total = yaml_kernels + _count_selected_prim_ops(selected_prim_ops_header) + else: + assert prim_ops_source is not None + total = yaml_kernels + _count_prim_ops(prim_ops_source) _write_if_different(output_path, HEADER_TEMPLATE.format(count=total)) return total @@ -147,11 +187,19 @@ def main(argv: List[str]) -> None: required=True, help="Path to a selected_operators.yaml. May be repeated.", ) - parser.add_argument( + prim_ops_group = parser.add_mutually_exclusive_group(required=True) + prim_ops_group.add_argument( "--prim-ops-source", "--prim_ops_source", - required=True, - help="Path to kernels/prim_ops/register_prim_ops.cpp.", + help="Path to kernels/prim_ops/register_prim_ops.cpp. Use this when " + "ET_PRIM_OPS_SELECTIVE_BUILD is not active (all prim ops compile in).", + ) + prim_ops_group.add_argument( + "--selected-prim-ops-header", + "--selected_prim_ops_header", + help="Path to the aggregated selected_prim_ops.h emitted by " + "gen_selected_prim_ops.py. 
Use this when ET_PRIM_OPS_SELECTIVE_BUILD " + "is active so only enabled prim ops are counted.", ) parser.add_argument( "--output-path", @@ -163,8 +211,13 @@ def main(argv: List[str]) -> None: count = gen_max_kernel_num( oplist_yamls=[Path(p) for p in args.oplist_yaml], - prim_ops_source=Path(args.prim_ops_source), output_path=Path(args.output_path), + prim_ops_source=Path(args.prim_ops_source) if args.prim_ops_source else None, + selected_prim_ops_header=( + Path(args.selected_prim_ops_header) + if args.selected_prim_ops_header + else None + ), ) if count is not None: print(f"gen_max_kernel_num: wrote {args.output_path} (count={count})") diff --git a/codegen/tools/targets.bzl b/codegen/tools/targets.bzl index b9d0100d8a2..298da8e77eb 100644 --- a/codegen/tools/targets.bzl +++ b/codegen/tools/targets.bzl @@ -188,6 +188,26 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) + runtime.python_library( + name = "gen_max_kernel_num_lib", + srcs = ["gen_max_kernel_num.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + ) + + runtime.python_binary( + name = "gen_max_kernel_num", + main_module = "executorch.codegen.tools.gen_max_kernel_num", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":gen_max_kernel_num_lib", + ], + _is_external_target = True, + ) + runtime.cxx_python_extension( name = "selective_build", diff --git a/codegen/tools/test/test_gen_max_kernel_num.py b/codegen/tools/test/test_gen_max_kernel_num.py index 1b701ad96d6..be013bc117d 100644 --- a/codegen/tools/test/test_gen_max_kernel_num.py +++ b/codegen/tools/test/test_gen_max_kernel_num.py @@ -12,6 +12,7 @@ from executorch.codegen.tools.gen_max_kernel_num import ( _count_prim_ops, + _count_selected_prim_ops, _count_yaml_kernels, gen_max_kernel_num, ) @@ -64,6 +65,91 @@ def test_counts_prim_ops_errors_when_array_empty(self) -> None: with self.assertRaises(RuntimeError): _count_prim_ops(empty_array) + def 
test_counts_selected_prim_ops_from_header(self) -> None: + header = self.tmp / "selected_prim_ops.h" + header.write_text( + "#pragma once\n" + "#define INCLUDE_ATEN_SYM_SIZE_INT\n" + "#define INCLUDE_EXECUTORCH_PRIM_ADD_INT_INT\n" + "// not a define: INCLUDE_FOO\n" + "#define INCLUDE_EXECUTORCH_PRIM_MUL_INT_INT\n" + ) + self.assertEqual(_count_selected_prim_ops(header), 3) + + def test_rejects_both_prim_ops_inputs(self) -> None: + yaml_path = self.tmp / "selected_operators.yaml" + _write_yaml( + yaml_path, + { + "operators": {"aten::add.out": {}}, + "et_kernel_metadata": {"aten::add.out": ["v1/6"]}, + }, + ) + header = self.tmp / "selected_prim_ops.h" + header.write_text("#define INCLUDE_FOO\n") + with self.assertRaises(ValueError): + gen_max_kernel_num( + oplist_yamls=[yaml_path], + prim_ops_source=self.prim_ops_source, + selected_prim_ops_header=header, + output_path=self.output, + ) + + def test_rejects_neither_prim_ops_input(self) -> None: + yaml_path = self.tmp / "selected_operators.yaml" + _write_yaml( + yaml_path, + { + "operators": {"aten::add.out": {}}, + "et_kernel_metadata": {"aten::add.out": ["v1/6"]}, + }, + ) + with self.assertRaises(ValueError): + gen_max_kernel_num( + oplist_yamls=[yaml_path], + output_path=self.output, + ) + + def test_empty_yaml_writes_opt_out_header(self) -> None: + yaml_path = self.tmp / "selected_operators.yaml" + _write_yaml(yaml_path, {"operators": {}, "et_kernel_metadata": {}}) + total = gen_max_kernel_num( + oplist_yamls=[yaml_path], + prim_ops_source=self.prim_ops_source, + output_path=self.output, + ) + self.assertIsNone(total) + self.assertTrue(self.output.exists()) + self.assertNotIn( + "#define EXECUTORCH_SELECTED_MAX_KERNEL_NUM", + self.output.read_text(), + ) + + def test_end_to_end_with_selected_prim_ops_header(self) -> None: + yaml_path = self.tmp / "selected_operators.yaml" + _write_yaml( + yaml_path, + { + "operators": {"aten::add.out": {}}, + "et_kernel_metadata": {"aten::add.out": ["v1/6"]}, + }, + ) + header = 
self.tmp / "selected_prim_ops.h" + header.write_text( + "#define INCLUDE_ATEN_SYM_SIZE_INT\n" + "#define INCLUDE_EXECUTORCH_PRIM_ADD_INT_INT\n" + ) + total = gen_max_kernel_num( + oplist_yamls=[yaml_path], + selected_prim_ops_header=header, + output_path=self.output, + ) + self.assertEqual(total, 1 + 2) + self.assertIn( + "#define EXECUTORCH_SELECTED_MAX_KERNEL_NUM 3", + self.output.read_text(), + ) + def test_counts_single_variant_per_op(self) -> None: yaml_path = self.tmp / "selected_operators.yaml" _write_yaml( diff --git a/runtime/kernel/selective_build.bzl b/runtime/kernel/selective_build.bzl new file mode 100644 index 00000000000..c42e87c77f9 --- /dev/null +++ b/runtime/kernel/selective_build.bzl @@ -0,0 +1,121 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load( + "@fbsource//xplat/executorch/runtime/kernel:targets.bzl", + "operator_registry_preprocessor_flags", +) + +# Layout of the per-binary header tree, matching the angle-bracket includes +# operator_registry.cpp uses (`<executorch/runtime/kernel/...>`). +_HEADER_DIR = "executorch/runtime/kernel" +_MAX_KERNEL_NUM_HEADER = _HEADER_DIR + "/selected_max_kernel_num.h" +_OP_REGISTRY_HEADER = _HEADER_DIR + "/operator_registry.h" + +def gen_max_kernel_num_genrule( + name, + oplist_yaml_target, + selected_prim_ops_header_target = None, + platforms = "CXX"): + """Run gen_max_kernel_num on a selected_operators.yaml and emit a header + that defines EXECUTORCH_SELECTED_MAX_KERNEL_NUM. + + When selected_prim_ops_header_target is provided (i.e. ET_PRIM_OPS_SELECTIVE_BUILD + is active for this binary), the prim ops contribution is counted from that + header so it matches what actually compiles in. Otherwise the count comes + from parsing register_prim_ops.cpp directly. + """ + + # Write the header flat at the artifact root so consumers can reference + # it as $(location :name)/selected_max_kernel_num.h (mirrors the + # selected_operators.yaml / selected_prim_ops.h conventions used by the + # adjacent genrules). 
+ cmd = ( + "$(exe //executorch/codegen/tools:gen_max_kernel_num) " + + "--oplist-yaml=$(location {})/selected_operators.yaml ".format(oplist_yaml_target) + + "--output-path=$OUT/selected_max_kernel_num.h " + ) + if selected_prim_ops_header_target: + cmd += "--selected-prim-ops-header=$(location {})/selected_prim_ops.h".format( + selected_prim_ops_header_target, + ) + else: + cmd += "--prim-ops-source=$(location //executorch/kernels/prim_ops:prim_ops_sources)/register_prim_ops.cpp" + + runtime.genrule( + name = name, + cmd = cmd, + outs = {"selected_max_kernel_num.h": ["selected_max_kernel_num.h"]}, + default_outs = ["."], + platforms = platforms, + ) + +def operator_registry_selective( + name, + selected_max_kernel_num_header_target, + aten_suffix = "", + platforms = "CXX", + **kwargs): + """Per-binary operator_registry variant whose registry capacity is sized + to the kernels its consumer actually selected. + + Stages operator_registry.cpp + operator_registry.h + the generated + selected_max_kernel_num.h in a single artifact tree, then compiles the + .cpp with both headers visible at the expected + `<executorch/runtime/kernel/...>` paths. operator_registry.cpp's existing + `__has_include` ladder picks up EXECUTORCH_SELECTED_MAX_KERNEL_NUM. A + user-supplied `-c executorch.max_kernel_num=N` still wins via the same + preprocessor flags the shared target uses. + + NOTE: the operator registry is intentionally a global; this target + defines the same external-linkage symbols as the shared + `//executorch/runtime/kernel:operator_registry`. Linking both into one + binary produces ODR / duplicate-symbol errors. Consumers must arrange + for only one to be linked transitively, which is why + `executorch_generated_lib(auto_size_kernel_registry = True)` is opt-in. 
+ """ + src_target = "//executorch/runtime/kernel:operator_registry_sources" + hdr_target = "//executorch/runtime/kernel:operator_registry_headers" + source_name = "operator_registry.cpp" + genrule_dep_name = name + "_operator_registry_srcs_copy" + + runtime.genrule( + name = genrule_dep_name, + cmd = " && ".join([ + "mkdir -p $OUT/{}".format(_HEADER_DIR), + "cp -f $(location {})/{} $OUT/{}".format(src_target, source_name, source_name), + "cp -f $(location {})/operator_registry.h $OUT/{}".format(hdr_target, _OP_REGISTRY_HEADER), + "cp -f $(location {})/selected_max_kernel_num.h $OUT/{}".format( + selected_max_kernel_num_header_target, + _MAX_KERNEL_NUM_HEADER, + ), + ]), + outs = { + source_name: [source_name], + "operator_registry.h": [_OP_REGISTRY_HEADER], + "selected_max_kernel_num.h": [_MAX_KERNEL_NUM_HEADER], + }, + default_outs = ["."], + platforms = platforms, + ) + + runtime.cxx_library( + name = name, + srcs = [":" + genrule_dep_name + "[" + source_name + "]"], + exported_headers = { + _OP_REGISTRY_HEADER: ":" + genrule_dep_name + "[operator_registry.h]", + _MAX_KERNEL_NUM_HEADER: ":" + genrule_dep_name + "[selected_max_kernel_num.h]", + }, + # The dict keys above are already fully-qualified include paths (e.g. + # `executorch/runtime/kernel/operator_registry.h`); env_interface.bzl + # would otherwise prepend `executorch//` to them. + header_namespace = "", + visibility = ["PUBLIC"], + # @lint-ignore BUCKLINT link_whole, the registry contains a global table. 
+ link_whole = True, + preprocessor_flags = operator_registry_preprocessor_flags(), + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core:evalue" + aten_suffix, + ], + platforms = platforms, + **kwargs + ) diff --git a/runtime/kernel/targets.bzl b/runtime/kernel/targets.bzl index 123dc1ac8da..4d623b15681 100644 --- a/runtime/kernel/targets.bzl +++ b/runtime/kernel/targets.bzl @@ -1,6 +1,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_options", "runtime") -def _operator_registry_preprocessor_flags(): +def operator_registry_preprocessor_flags(): max_kernel_num = native.read_config("executorch", "max_kernel_num", None) if max_kernel_num != None: return select({ @@ -24,6 +24,20 @@ def define_common_targets(): TARGETS and BUCK files that call this function. """ + # Exposed so operator_registry_selective() in + # shim_et/xplat/executorch/runtime/kernel/selective_build.bzl can stage the + # source alongside a per-binary selected_max_kernel_num.h. 
+ runtime.filegroup( + name = "operator_registry_sources", + srcs = ["operator_registry.cpp"], + visibility = ["PUBLIC"], + ) + runtime.filegroup( + name = "operator_registry_headers", + srcs = ["operator_registry.h"], + visibility = ["PUBLIC"], + ) + runtime.cxx_library( name = "operator_registry_MAX_NUM_KERNELS_TEST_ONLY", srcs = ["operator_registry.cpp"], @@ -62,7 +76,7 @@ def define_common_targets(): "//executorch/runtime/core:core", "//executorch/runtime/core:evalue" + aten_suffix, ], - preprocessor_flags = _operator_registry_preprocessor_flags(), + preprocessor_flags = operator_registry_preprocessor_flags(), ) runtime.cxx_library( diff --git a/shim_et/xplat/executorch/codegen/codegen.bzl b/shim_et/xplat/executorch/codegen/codegen.bzl index 8b17171ec4e..9f5cdc97a33 100644 --- a/shim_et/xplat/executorch/codegen/codegen.bzl +++ b/shim_et/xplat/executorch/codegen/codegen.bzl @@ -8,6 +8,11 @@ load( load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list") load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list") load("@fbsource//xplat/executorch/kernels/prim_ops:selective_build.bzl", "prim_ops_registry_selective") +load( + "@fbsource//xplat/executorch/runtime/kernel:selective_build.bzl", + "gen_max_kernel_num_genrule", + "operator_registry_selective", +) # Headers that declare the function signatures of the C++ functions that # map to entries in functions.yaml and custom_ops.yaml. @@ -82,6 +87,47 @@ ScalarType = enum( "Uint64", ) +def _combined_prim_ops_header_target_name(name): + """Single source of truth for the deterministic name of the per-binary + aggregated selected_prim_ops.h target. 
Both _get_prim_ops_registry_target + (which creates it) and _get_operator_registry_target (which references it + when include_all_prim_ops=False) must use the same name.""" + return name + "_combined_prim_ops_header" + +def _get_operator_registry_target( + name, + oplist_dir_name, + aten_suffix, + platforms, + include_all_prim_ops): + """Create a per-binary operator_registry variant whose registry capacity is + sized to the kernels actually selected by this binary's et_operator_library + deps. When include_all_prim_ops is False, also threads the binary's + combined selected_prim_ops.h into the counter so prim ops compiled out + under ET_PRIM_OPS_SELECTIVE_BUILD aren't counted. + """ + max_kernel_num_genrule_name = name + "_selected_max_kernel_num" + + selected_prim_ops_header = None + if not include_all_prim_ops: + selected_prim_ops_header = ":" + _combined_prim_ops_header_target_name(name) + + gen_max_kernel_num_genrule( + name = max_kernel_num_genrule_name, + oplist_yaml_target = ":" + oplist_dir_name, + selected_prim_ops_header_target = selected_prim_ops_header, + platforms = platforms, + ) + + operator_registry_target_name = name + "_operator_registry" + aten_suffix + operator_registry_selective( + name = operator_registry_target_name, + selected_max_kernel_num_header_target = ":" + max_kernel_num_genrule_name, + aten_suffix = aten_suffix, + platforms = platforms, + ) + return ":" + operator_registry_target_name + def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms): """ Helper function to determine which prim ops registry target to use. 
@@ -100,7 +146,7 @@ def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms): # If selective build targets are specified, create a selective prim ops registry # Create a selective prim ops registry using the existing function selective_prim_ops_registry_name = name + "_selected_prim_ops_registry" - combined_prim_ops_header_target_name = name + "_combined_prim_ops_header" + combined_prim_ops_header_target_name = _combined_prim_ops_header_target_name(name) selected_prim_operators_genrule(combined_prim_ops_header_target_name, deps, platforms) # Use the existing prim_ops_registry_selective function @@ -825,7 +871,8 @@ def executorch_generated_lib( compatible_with = None, expose_operator_symbols = False, support_exceptions = True, - include_all_prim_ops = True): + include_all_prim_ops = True, + auto_size_kernel_registry = False): """Emits 0-3 C++ library targets (in fbcode or xplat) containing code to dispatch the operators specified in the provided yaml files. @@ -894,6 +941,23 @@ def executorch_generated_lib( include_all_prim_ops: If true, include all prim ops in the generated library. This option allows for selecting only some prim ops to reduce code size for extremely constrained environments. For selecting only some prim ops, see examples in //executorch/examples/selective_build + auto_size_kernel_registry: Opt-in (default False). When True, emit a + per-binary operator_registry variant whose registry capacity is + sized to the kernels selected via the binary's + et_operator_library deps. Saves ~24 KiB / ~48 KiB of BSS on + 32-/64-bit targets compared to the kMaxOperators*kMaxKernelsPerOp + default. A user-supplied `-c executorch.max_kernel_num=N` still + overrides via the existing preprocessor-flag path. + + CAVEAT: the operator registry is intentionally a global, so + this variant defines the same external-linkage symbols as the + shared //executorch/runtime/kernel:operator_registry. Any + transitive dep on the shared target (e.g. 
the standard + runtime/executor program library, the shared prim_ops_registry + via include_all_prim_ops=True, extension/kernel_util, + extension/pybindings) will collide at link time. Consumers + turning this on are responsible for ensuring their binary's + link graph contains exactly one operator_registry target. """ _compat_kwargs = {} if compatible_with != None: @@ -1071,6 +1135,17 @@ def executorch_generated_lib( else: prim_ops_registry_target = _get_prim_ops_registry_target(name, deps, aten_suffix, platforms) + if auto_size_kernel_registry: + operator_registry_target = _get_operator_registry_target( + name = name, + oplist_dir_name = oplist_dir_name, + aten_suffix = aten_suffix, + platforms = platforms, + include_all_prim_ops = include_all_prim_ops, + ) + else: + operator_registry_target = "//executorch/runtime/kernel:operator_registry" + aten_suffix + runtime.cxx_library( name = lib_name, srcs = [ @@ -1094,7 +1169,7 @@ def executorch_generated_lib( "ovr_config//os:windows": [], }) + compiler_flags, deps = [ - "//executorch/runtime/kernel:operator_registry" + aten_suffix, + operator_registry_target, prim_ops_registry_target, # Use the appropriate prim ops registry "//executorch/runtime/core:evalue" + aten_suffix, "//executorch/codegen:macros",