diff --git a/crates/cuda_builder/src/lib.rs b/crates/cuda_builder/src/lib.rs
index 6441b827..3620c50a 100644
--- a/crates/cuda_builder/src/lib.rs
+++ b/crates/cuda_builder/src/lib.rs
@@ -266,6 +266,21 @@ impl CudaBuilder {
         self
     }
 
+    /// Enable fast math approximations globally (equivalent to NVCC's `--use_fast_math`).
+    /// Sets `ftz=true`, `fast_sqrt=true`, `fast_div=true`, and `fma_contraction=true`.
+    /// Individual flags can still be overridden afterward.
+    ///
+    /// Note: this sacrifices IEEE 754 compliance for performance. Single-precision
+    /// division and square root will have up to 2 ULP error, and denormal values
+    /// will be flushed to zero.
+    pub fn fast_math(mut self) -> Self {
+        self.ftz = true;
+        self.fast_sqrt = true;
+        self.fast_div = true;
+        self.fma_contraction = true;
+        self
+    }
+
     /// Use a fast approximation for single-precision floating point square root.
     pub fn fast_sqrt(mut self, fast_sqrt: bool) -> Self {
         self.fast_sqrt = fast_sqrt;