|
| 1 | +module DifferentiationInterfaceTestChairmarksExt |
| 2 | + |
| 3 | +using ADTypes: AbstractADType |
| 4 | +using Chairmarks: @be, Benchmark, Sample |
| 5 | +import DifferentiationInterface as DI |
| 6 | +using DifferentiationInterface: |
| 7 | + prepare_pushforward, |
| 8 | + prepare_pushforward_same_point, |
| 9 | + prepare!_pushforward, |
| 10 | + pushforward, |
| 11 | + pushforward!, |
| 12 | + value_and_pushforward, |
| 13 | + value_and_pushforward!, |
| 14 | + prepare_pullback, |
| 15 | + prepare_pullback_same_point, |
| 16 | + prepare!_pullback, |
| 17 | + pullback, |
| 18 | + pullback!, |
| 19 | + value_and_pullback, |
| 20 | + value_and_pullback!, |
| 21 | + prepare_derivative, |
| 22 | + prepare!_derivative, |
| 23 | + derivative, |
| 24 | + derivative!, |
| 25 | + value_and_derivative, |
| 26 | + value_and_derivative!, |
| 27 | + prepare_gradient, |
| 28 | + prepare!_gradient, |
| 29 | + gradient, |
| 30 | + gradient!, |
| 31 | + value_and_gradient, |
| 32 | + value_and_gradient!, |
| 33 | + prepare_jacobian, |
| 34 | + prepare!_jacobian, |
| 35 | + jacobian, |
| 36 | + jacobian!, |
| 37 | + value_and_jacobian, |
| 38 | + value_and_jacobian!, |
| 39 | + prepare_second_derivative, |
| 40 | + prepare!_second_derivative, |
| 41 | + second_derivative, |
| 42 | + second_derivative!, |
| 43 | + value_derivative_and_second_derivative, |
| 44 | + value_derivative_and_second_derivative!, |
| 45 | + prepare_hvp, |
| 46 | + prepare_hvp_same_point, |
| 47 | + prepare!_hvp, |
| 48 | + hvp, |
| 49 | + hvp!, |
| 50 | + gradient_and_hvp, |
| 51 | + gradient_and_hvp!, |
| 52 | + prepare_hessian, |
| 53 | + prepare!_hessian, |
| 54 | + hessian, |
| 55 | + hessian!, |
| 56 | + value_gradient_and_hessian, |
| 57 | + value_gradient_and_hessian! |
| 58 | +import DifferentiationInterfaceTest as DIT |
| 59 | +using DifferentiationInterfaceTest: |
| 60 | + ALL_OPS, |
| 61 | + CallCounter, CallsResult, DifferentiationBenchmarkDataRow, DifferentiationBenchmark, Scenario, |
| 62 | + mysimilar, reset_count! |
| 63 | +using Test |
| 64 | + |
| 65 | +function failed_bench() |
| 66 | + evals = 0.0 |
| 67 | + time = NaN |
| 68 | + allocs = NaN |
| 69 | + bytes = NaN |
| 70 | + gc_fraction = NaN |
| 71 | + compile_fraction = NaN |
| 72 | + recompile_fraction = NaN |
| 73 | + warmup = NaN |
| 74 | + checksum = NaN |
| 75 | + sample = Sample( |
| 76 | + evals, |
| 77 | + time, |
| 78 | + allocs, |
| 79 | + bytes, |
| 80 | + gc_fraction, |
| 81 | + compile_fraction, |
| 82 | + recompile_fraction, |
| 83 | + warmup, |
| 84 | + checksum, |
| 85 | + ) |
| 86 | + return Benchmark([sample]) |
| 87 | +end |
| 88 | + |
| 89 | +@kwdef struct BenchmarkResult |
| 90 | + prepared_valop::Benchmark = failed_bench() |
| 91 | + prepared_op::Benchmark = failed_bench() |
| 92 | + preparation::Benchmark = failed_bench() |
| 93 | + unprepared_valop::Benchmark = failed_bench() |
| 94 | + unprepared_op::Benchmark = failed_bench() |
| 95 | +end |
| 96 | + |
| 97 | + |
| 98 | +function record!( |
| 99 | + data::DifferentiationBenchmark; |
| 100 | + backend::AbstractADType, |
| 101 | + scenario::Scenario, |
| 102 | + operator::String, |
| 103 | + prepared::Union{Nothing, Bool}, |
| 104 | + bench::Benchmark, |
| 105 | + calls::Integer, |
| 106 | + aggregation, |
| 107 | + ) |
| 108 | + row = DifferentiationBenchmarkDataRow(; |
| 109 | + backend = backend, |
| 110 | + scenario = scenario, |
| 111 | + operator = Symbol(operator), |
| 112 | + prepared = prepared, |
| 113 | + calls = calls, |
| 114 | + samples = length(bench.samples), |
| 115 | + evals = Int(bench.samples[1].evals), |
| 116 | + time = aggregation(getfield.(bench.samples, :time)), |
| 117 | + allocs = aggregation(getfield.(bench.samples, :allocs)), |
| 118 | + bytes = aggregation(getfield.(bench.samples, :bytes)), |
| 119 | + gc_fraction = aggregation(getfield.(bench.samples, :gc_fraction)), |
| 120 | + compile_fraction = aggregation(getfield.(bench.samples, :compile_fraction)), |
| 121 | + ) |
| 122 | + return push!(data.rows, row) |
| 123 | +end |
| 124 | + |
| 125 | +include("benchmark_eval.jl") |
| 126 | + |
| 127 | +end |
0 commit comments