-
Notifications
You must be signed in to change notification settings - Fork 817
feat: Add SVE kernels for TopKV #1256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
morgolock
wants to merge
1
commit into
main
Choose a base branch
from
pr/topkv_sve_kernels
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+525
−18
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| /* | ||
| * Copyright (c) 2026 Arm Limited. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to | ||
| * deal in the Software without restriction, including without limitation the | ||
| * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||
| * sell copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in all | ||
| * copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| * SOFTWARE. | ||
| */ | ||
| #if defined(__ARM_FEATURE_SVE) | ||
|
|
||
| #include "src/cpu/kernels/topkv/generic/sve/impl.h" | ||
|
|
||
| #include <arm_sve.h> | ||
|
|
||
| namespace arm_compute | ||
| { | ||
| namespace cpu | ||
| { | ||
| namespace detail | ||
| { | ||
|
|
||
| template <> | ||
| inline uint32_t vector_length<float16_t>() | ||
| { | ||
| return static_cast<uint32_t>(svcnth()); | ||
| } | ||
|
|
||
| template <> | ||
| inline uint32_t count_gt_block<float16_t>(const float16_t *ptr, float16_t thr, uint32_t block_elems) | ||
| { | ||
| const svbool_t pg = svwhilelt_b16(static_cast<uint64_t>(0), static_cast<uint64_t>(block_elems)); | ||
| const svfloat16_t v = svld1_f16(pg, ptr); | ||
| const svbool_t gt = svcmpgt_n_f16(pg, v, thr); | ||
| return static_cast<uint32_t>(svcntp_b16(svptrue_b16(), gt)); | ||
| } | ||
|
|
||
| } // namespace detail | ||
|
|
||
| void topkv_fp16_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win) | ||
| { | ||
| detail::topkv_sve_wrapper<float16_t>(predictions, targets, out, k, win); | ||
| } | ||
|
|
||
| // Force instantiation into this TU | ||
| template void | ||
| detail::topkv_sve_wrapper<float16_t>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &); | ||
|
|
||
| } // namespace cpu | ||
| } // namespace arm_compute | ||
|
|
||
| #endif // __ARM_FEATURE_SVE |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| /* | ||
| * Copyright (c) 2026 Arm Limited. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to | ||
| * deal in the Software without restriction, including without limitation the | ||
| * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||
| * sell copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in all | ||
| * copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| * SOFTWARE. | ||
| */ | ||
| #if defined(__ARM_FEATURE_SVE) | ||
|
|
||
| #include "src/cpu/kernels/topkv/generic/sve/impl.h" | ||
|
|
||
| #include <arm_sve.h> | ||
| #include <cstdint> | ||
|
|
||
| namespace arm_compute | ||
| { | ||
| namespace cpu | ||
| { | ||
| namespace detail | ||
| { | ||
|
|
||
| template <> | ||
| inline uint32_t vector_length<float>() | ||
| { | ||
| return static_cast<uint32_t>(svcntw()); | ||
| } | ||
|
|
||
| template <> | ||
| inline uint32_t count_gt_block<float>(const float *ptr, float thr, uint32_t block_elems) | ||
| { | ||
| const svbool_t pg = svwhilelt_b32(static_cast<uint64_t>(0), static_cast<uint64_t>(block_elems)); | ||
| const svfloat32_t v = svld1_f32(pg, ptr); | ||
| const svbool_t gt = svcmpgt_n_f32(pg, v, thr); | ||
| return static_cast<uint32_t>(svcntp_b32(svptrue_b32(), gt)); | ||
| } | ||
|
|
||
| } // namespace detail | ||
|
|
||
| void topkv_fp32_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win) | ||
| { | ||
| detail::topkv_sve_wrapper<float>(predictions, targets, out, k, win); | ||
| } | ||
|
|
||
| // Force instantiation into this TU | ||
| template void detail::topkv_sve_wrapper<float>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &); | ||
|
|
||
| } // namespace cpu | ||
| } // namespace arm_compute | ||
|
|
||
| #endif // __ARM_FEATURE_SVE |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| /* | ||
| * Copyright (c) 2026 Arm Limited. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to | ||
| * deal in the Software without restriction, including without limitation the | ||
| * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||
| * sell copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in all | ||
| * copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| * SOFTWARE. | ||
| */ | ||
| #ifndef ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H | ||
| #define ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H | ||
|
|
||
| #include "arm_compute/core/Coordinates.h" | ||
| #include "arm_compute/core/Error.h" | ||
| #include "arm_compute/core/Helpers.h" | ||
| #include "arm_compute/core/ITensor.h" | ||
| #include "arm_compute/core/Types.h" | ||
| #include "arm_compute/core/Window.h" | ||
|
|
||
| #include <cstdint> | ||
| #include <cstring> | ||
|
|
||
| namespace arm_compute | ||
| { | ||
| namespace cpu | ||
| { | ||
| namespace detail | ||
| { | ||
|
|
||
| /* | ||
| * Type-specific hooks (declared here, defined in each cpp). | ||
| * | ||
| * - vector_length<Scalar>() | ||
| * Return the SVE vector length in elements for Scalar (no clamping). | ||
| * | ||
| * - count_gt_block<Scalar>(ptr, thr, block_elems) | ||
| * Count how many elements in [ptr, ptr + block_elems) are > thr. | ||
| * Tail-safe via predicate. block_elems is always <= vector_length<Scalar>(). | ||
| * | ||
| t contains the SVE intrinsics | ||
| * (e.g., qasymm8.cpp, qasymm8_signed.cpp, fp16.cpp, fp32.cpp, integer.cpp). | ||
| */ | ||
|
|
||
| template <typename Scalar> | ||
| uint32_t vector_length(); | ||
|
|
||
| template <typename Scalar> | ||
| uint32_t count_gt_block(const Scalar *ptr, Scalar thr, uint32_t block_elems); | ||
|
|
||
| // ---------------------------------------------------------------------------- | ||
| // Generic wrapper (type-agnostic) - uses the above hooks. | ||
| // Semantics (matching TopKV tests you showed): | ||
| // - predictions is N x C | ||
| // - window iterates across output elements (classes) => id.x() == class index c | ||
| // - for each class c, targets[c] gives the sample index t | ||
| // - scan across N samples and compute rank (#samples with value > predictions[t]) | ||
| // - output is U8 boolean: (rank < k) | ||
| // ---------------------------------------------------------------------------- | ||
| template <typename Scalar> | ||
| inline void | ||
| topkv_sve_wrapper(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &window) | ||
| { | ||
| ARM_COMPUTE_ERROR_ON_NULLPTR(predictions, targets, out); | ||
| ARM_COMPUTE_ERROR_ON(k == 0); | ||
|
|
||
| const ITensorInfo *pred_info = predictions->info(); | ||
| const uint32_t N = pred_info->dimension(0); // samples | ||
| const uint32_t C = pred_info->dimension(1); // classes | ||
|
|
||
| const uint32_t vl = vector_length<Scalar>(); // cache once per kernel invocation | ||
|
|
||
| Iterator tgt_it(targets, window); | ||
| Iterator out_it(out, window); | ||
|
|
||
| execute_window_loop( | ||
| window, | ||
| [&](const Coordinates &id) | ||
| { | ||
| const uint32_t c = static_cast<uint32_t>(id.x()); // class index | ||
| ARM_COMPUTE_ERROR_ON(c >= C); | ||
|
|
||
| uint32_t t = {*reinterpret_cast<uint32_t *>(tgt_it.ptr())}; | ||
| ARM_COMPUTE_ERROR_ON(t >= N); | ||
|
|
||
| const Scalar *col_ptr = reinterpret_cast<const Scalar *>(predictions->ptr_to_element(Coordinates(0, c))); | ||
| ARM_COMPUTE_ERROR_ON(col_ptr == nullptr); | ||
|
|
||
| const Scalar thr = col_ptr[t]; | ||
|
|
||
| uint32_t rank = 0; | ||
| uint32_t idx = 0; | ||
|
|
||
| while (idx < N) | ||
| { | ||
| const uint32_t remaining = N - idx; | ||
| const uint32_t bw = (remaining < vl) ? remaining : vl; | ||
|
|
||
| rank += count_gt_block<Scalar>(col_ptr + idx, thr, bw); | ||
|
|
||
| if (rank >= k) | ||
| { | ||
| break; | ||
| } | ||
|
|
||
| idx += bw; | ||
| } | ||
|
|
||
| *reinterpret_cast<uint8_t *>(out_it.ptr()) = static_cast<uint8_t>(rank < k); | ||
| }, | ||
| tgt_it, out_it); | ||
| } | ||
|
|
||
| } // namespace detail | ||
| } // namespace cpu | ||
| } // namespace arm_compute | ||
|
|
||
| #endif // ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The convention in this file is two spaces between each level.