From 30e817e1a6dcf23d83d77e43ad92a8a4d646e598 Mon Sep 17 00:00:00 2001 From: Alexander Voigt Date: Tue, 24 Mar 2026 07:30:39 +0100 Subject: [PATCH] refactor test of ForwardDiff and ChainRules --- test/ForwardDiff.jl | 37 +++++++++++++++++++++++++++++++++++++ test/Li.jl | 8 -------- test/Li0.jl | 7 ------- test/Li1.jl | 8 -------- test/Li2.jl | 8 -------- test/Li3.jl | 8 -------- test/Li4.jl | 8 -------- test/runtests.jl | 5 +---- 8 files changed, 38 insertions(+), 51 deletions(-) create mode 100644 test/ForwardDiff.jl diff --git a/test/ForwardDiff.jl b/test/ForwardDiff.jl new file mode 100644 index 0000000..7dc1f07 --- /dev/null +++ b/test/ForwardDiff.jl @@ -0,0 +1,37 @@ +if isdefined(Base, :get_extension) + import ForwardDiff + import ChainRulesTestUtils + + @testset "ForwardDiff" begin + @test ForwardDiff.derivative(PolyLog.li0, float(pi)) == 1/(1 - pi)^2 + @test ForwardDiff.derivative(PolyLog.li0, 0.0) == 1 + @test ForwardDiff.derivative(PolyLog.li0, 1.0) == Inf + + @test ForwardDiff.derivative(PolyLog.reli1, float(pi)) == 1/(1 - pi) + @test ForwardDiff.derivative(PolyLog.reli1, 0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi)) + + @test ForwardDiff.derivative(PolyLog.reli2, float(pi)) == PolyLog.reli1(pi)/pi + @test ForwardDiff.derivative(PolyLog.reli2, 0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi)) + + @test ForwardDiff.derivative(PolyLog.reli3, float(pi)) == PolyLog.reli2(pi)/pi + @test ForwardDiff.derivative(PolyLog.reli3, 0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi)) + + @test ForwardDiff.derivative(PolyLog.reli4, float(pi)) == PolyLog.reli3(pi)/pi + @test ForwardDiff.derivative(PolyLog.reli4, 0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi)) + + for n in vcat(collect(-10:10), [100, 1000000]) + @test ForwardDiff.derivative(z -> PolyLog.reli(n, z), float(pi)) == PolyLog.reli(n - 1, pi)/pi + @test ForwardDiff.derivative(z -> PolyLog.reli(n, z), 0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi)) + end + end +end diff --git a/test/Li.jl b/test/Li.jl index a98967b..b89b588 100644 --- a/test/Li.jl +++ b/test/Li.jl @@ -95,14 +95,6 @@ end @test PolyLog.li(n, 1//1 + 0//1im) ≈ zeta @test PolyLog.li(n, 1 + 0im) ≈ zeta @test PolyLog.li(n, BigFloat("1.0") + 0im) == PolyLog.zeta(n, BigFloat) - - # ForwardDiff Test - if isdefined(Base, :get_extension) - @test ForwardDiff.derivative(z -> PolyLog.reli(n, z), float(pi)) == PolyLog.reli(n - 1, pi)/pi - @test ForwardDiff.derivative(z -> PolyLog.reli(n, z), 0.0) == 1.0 - ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0) - ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi)) - end end # value close to boundary between series 1 and 2 in arXiv:2010.09860 diff --git a/test/Li0.jl b/test/Li0.jl index 1929ab9..f16dbe0 100644 --- a/test/Li0.jl +++ b/test/Li0.jl @@ -50,11 +50,4 @@ # test value that causes overflow if squared @test PolyLog.li0(1e300 + 1im) ≈ -1.0 rtol=eps(Float64) @test PolyLog.li0(1.0 + 1e300im) ≈ -1.0 rtol=eps(Float64) - - # ForwardDiff Test - if isdefined(Base, :get_extension) - @test ForwardDiff.derivative(PolyLog.li0, float(pi)) == 1/(1 - pi)^2 - @test ForwardDiff.derivative(PolyLog.li0, 0.0) == 1 - @test ForwardDiff.derivative(PolyLog.li0, 1.0) == Inf - end end diff --git a/test/Li1.jl b/test/Li1.jl index ce2f629..d3dc8fd 100644 --- a/test/Li1.jl +++ b/test/Li1.jl @@ -59,12 +59,4 @@ # test value that causes overflow if squared @test PolyLog.li1(1e300 + 1im) ≈ -690.77552789821371 + 3.14159265358979im rtol=eps(Float64) @test PolyLog.li1(1.0 + 1e300im) ≈ -690.77552789821371 + 1.5707963267948966im rtol=eps(Float64) - - # ForwardDiff Test - if isdefined(Base, :get_extension) - @test ForwardDiff.derivative(PolyLog.reli1, float(pi)) == 1/(1 - pi) - @test ForwardDiff.derivative(PolyLog.reli1, 0.0) == 1.0 - ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0) - ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi)) - end end diff --git a/test/Li2.jl b/test/Li2.jl index fc899f3..b1ff424 100644 --- a/test/Li2.jl +++ b/test/Li2.jl @@ -119,12 +119,4 @@ end # test value that causes overflow if squared @test PolyLog.li2(1e300 + 1im) ≈ -238582.12510339421 + 2170.13532372464im rtol=eps(Float64) @test PolyLog.li2(1.0 + 1e300im) ≈ -238585.82620504462 + 1085.06766186232im rtol=eps(Float64) - - # ForwardDiff Test - if isdefined(Base, :get_extension) - @test ForwardDiff.derivative(PolyLog.reli2, float(pi)) == PolyLog.reli1(pi)/pi - @test ForwardDiff.derivative(PolyLog.reli2, 0.0) == 1.0 - ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0) - ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi)) - end end diff --git a/test/Li3.jl b/test/Li3.jl index e57fc0d..dc0dddb 100644 --- a/test/Li3.jl +++ b/test/Li3.jl @@ -60,12 +60,4 @@ # test value that causes overflow if squared @test PolyLog.li3(1e300 + 1im) ≈ -5.4934049431527088e7 + 749538.186928224im rtol=eps(Float64) @test PolyLog.li3(1.0 + 1e300im) ≈ -5.4936606061973454e7 + 374771.031356405im rtol=eps(Float64) - - # ForwardDiff Test - if isdefined(Base, :get_extension) - @test ForwardDiff.derivative(PolyLog.reli3, float(pi)) == PolyLog.reli2(pi)/pi - @test ForwardDiff.derivative(PolyLog.reli3, 0.0) == 1.0 - ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0) - ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi)) - end end diff --git a/test/Li4.jl b/test/Li4.jl index 2c6db93..70e6aaf 100644 --- a/test/Li4.jl +++ b/test/Li4.jl @@ -52,12 +52,4 @@ # test value that causes overflow if squared @test PolyLog.li4(1e300 + 1im) ≈ -9.4863817894708364e9 + 1.725875455850714e8im rtol=eps(Float64) @test PolyLog.li4(1.0 + 1e300im) ≈ -9.4872648206269765e9 + 8.62951114411071e7im rtol=eps(Float64) - - # ForwardDiff Test - if isdefined(Base, :get_extension) - @test ForwardDiff.derivative(PolyLog.reli4, float(pi)) == PolyLog.reli3(pi)/pi - @test ForwardDiff.derivative(PolyLog.reli4, 0.0) == 1.0 - ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0) - ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi)) - end end diff --git a/test/runtests.jl b/test/runtests.jl index b9d414f..65a1de8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,5 @@ using Test import PolyLog -if isdefined(Base, :get_extension) - import ForwardDiff - import ChainRulesTestUtils -end include("TestPrecision.jl") include("DataReader.jl") @@ -12,6 +8,7 @@ include("Digamma.jl") include("Dual.jl") include("Eta.jl") include("Factorial.jl") +include("ForwardDiff.jl") include("Harmonic.jl") include("Li0.jl") include("Li1.jl")