Metaohjelmointia Julialla: symbolinen derivointi

Esittelen seuraavaksi erään ihan todellisen käyttötapauksen, missä makrojen käytöstä voi olla todellista hyötyä. Kuvitellaan tilannetta, missä meillä on olemassa jokin analyyttinen funktio, yksinkertaisuuden vuoksi polynomi. Haluamme laskea tämän polynomin analyyttisiä derivaattoja nopeasti. Toteutetaan yksinkertainen CAS-laskin, joka osaa derivoida polynomeja. Käytän tässä pohjana John Myles Whiten blogikirjoitusta.

Juliassa Expr-tyyppi on ohjelmakoodin esitys, "ajamatonta koodia". Esimerkiksi,

Expr(:call, :+, 1, 1)
:(1 + 1)

Expr-tyypin suoritus tapahtuu eval-funktiolla:

eval(:(1 + 1))
2

Myös dump-komento on hyödyllinen:

dump(:(1 + 1))
Expr
  head: Symbol call
  args: Array{Any}((3,))
    1: Symbol +
    2: Int64 1
    3: Int64 1

Voimme rakentaa Julia-ohjelmakoodia ohjelmallisesti:

e = Expr(:call)
e.args = [:+, 1, 1]
eval(e)
2
myfunc = :(function hello() println("Hello, world!") end)
:(function hello()
      println("Hello, world")
  end)
eval(myfunc)
hello()
Hello, world!

FEMBasis-projektissa pitää laskea polynomeja ja niiden derivaattoja. Tehtävä ei ole kovin monimutkainen. Esimerkiksi, pitäisi voida laskea vaikkapa funktion

\[\frac{\partial\left(2w^{2}uv^{2}\right)}{\partial u}\]

arvo joillakin parametreilla $(u, v, w)$. Perinteinen tapa ratkaista asia on ottaa esille kynä ja paperia sekä ratkaista osittaisderivaatat käsin. Toinen tapa on laskea osittaisderivaatat jollakin symbolisella laskimella sekä copypastettaa vastaus koodiin. Kokeillaan nyt näistä vielä parannettua versiota, jossa itse asiassa laskemme osittaisderivaatan analyyttisesti ohjelman käännösvaiheessa, jotta se olisi nopeaa varsinaisessa ajovaiheessa.

Esimerkiksi mainittu funktio olisi Julian AST:ssä:

dump(:(2*w^2*u*v^2))
Expr
  head: Symbol call
  args: Array{Any}((5,))
    1: Symbol *
    2: Int64 2
    3: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol ^
        2: Symbol w
        3: Int64 2
    4: Symbol u
    5: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol ^
        2: Symbol v
        3: Int64 2

Derivoituna vaikkapa $u$:n suhteen haluaisimme

dump(:(2*w^2*v^2))
Expr
  head: Symbol call
  args: Array{Any}((4,))
    1: Symbol *
    2: Int64 2
    3: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol ^
        2: Symbol w
        3: Int64 2
    4: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol ^
        2: Symbol v
        3: Int64 2

Rakennellaan derivointia pala kerrallaan. Hyödynnetään Julian ns. multiple dispatchia, joka tunnetaan esimerkiksi Javassa tai C++:ssa method overloadingina. Eli voi olla samannimisiä funktioita, ja kutsuttavan funktio riippuu sen argumenteista. Valitaan funktion nimeksi differentiate, ja käyttö sillä tavalla, että differentiate(f, :x), missä f on jokin lauseke, palauttaa lausekkeen derivoituna x:n suhteen.

Ensimmäisenä vakion derivointi,

\[\frac{\mathrm{d}a}{\mathrm{d}x} = 0.\]

function differentiate(::Number, ::Symbol)
    return 0
end


@show differentiate(5, :x)
differentiate(5, :x) = 0

Seuraavaksi muuttujan derivointi:

function differentiate(s::Symbol, t::Symbol)
    return s == t ? 1 : 0
end

@show differentiate(:x, :x)
@show differentiate(:y, :x)
differentiate(:x, :x) = 1
differentiate(:y, :x) = 0

Seuraavaksi ketjusääntö. Toteutetaan yleisessä muodossa, missä muuttujia voi olla enemmänkin kuin kaksi. Esimerkiksi kolmelle x:stä riippuvalle muuttujalle

\[ \frac{\mathrm{d}}{\mathrm{d}x}\left(fgh\right)=\frac{\mathrm{d}f}{\mathrm{d}x}gh+f\frac{\mathrm{d}g}{\mathrm{d}x}h+fg\frac{\mathrm{d}h}{\mathrm{d}x}, \]

joka ohjelmointikielisenä syntaksina olisi ehkäpäkin

diff(*(f, g, h), x) = +(
    *(diff(f, x), g, h),
    *(f, diff(g, x), h),
    *(f, g, diff(h, x)))
function differentiate(::Type{Val{:*}}, ex::Expr, t::Symbol)
    @assert first(ex.args) == :*
    res_args = Any[:+]
    for i in 2:length(ex.args)
        new_args = copy(ex.args)
        new_args[i] = differentiate(ex.args[i], t)
        push!(res_args, Expr(:call, new_args...))
    end
    return Expr(:call, res_args...)
end

f1 = :(2*x)
f2 = :(x*u*w)
f3 = :(x*2)
f4 = :(2*2)
@show differentiate(Val{:*}, f1, :x)
@show differentiate(Val{:*}, f2, :x)
@show differentiate(Val{:*}, f3, :y)
@show differentiate(Val{:*}, f4, :x)
differentiate(Val{:*}, f1, :x) = :(0x + 2 * 1)
differentiate(Val{:*}, f2, :x) = :(1 * u * w + x * 0 * w + x * u * 0)
differentiate(Val{:*}, f3, :x) = :(0 * 2 + x * 0)
differentiate(Val{:*}, f4, :x) = :(0 * 2 + 2 * 0)

Ketjusäännön soveltamisen seurauksena tulee jonkin verran nollalla kertomisia, jotka voisi jonkinlaisella simplify-algoritmilla hoitaa pois. Asiaan palataan. Seuraavaksi potenssisääntö:

\[ \frac{\mathrm{d}}{\mathrm{d}x}\left(f\left(x\right)^{a}\right)=af\left(x\right)^{a-1}\frac{df}{dx} \]

function differentiate(::Type{Val{:^}}, ex::Expr, t::Symbol)
    op, f, a = ex.args
    return :($a * $f ^ ($a - 1) * $(differentiate(f, t)))
end

f1 = :(x^2)
f2 = :(x^2)
f3 = :(y^2)
f4 = :(y^2)
@show differentiate(Val{:^}, f1, :x)
@show differentiate(Val{:^}, f2, :y)
@show differentiate(Val{:^}, f3, :x)
@show differentiate(Val{:^}, f4, :y)
differentiate(Val{:^}, f1, :x) = :(2 * x ^ (2 - 1))
differentiate(Val{:^}, f2, :y) = 0
differentiate(Val{:^}, f3, :x) = 0
differentiate(Val{:^}, f4, :y) = :(2 * y ^ (2 - 1))

Lopuksi rakennetaan pääfunktio, joka käynnistää minkä tahansa Expr-muodossa annetun lausekkeen symbolisen derivoinnin:

function differentiate(ex::Expr, t::Symbol)
    if ex.head == :call
        return differentiate(Val{ex.args[1]}, ex, t)
    else
        return differentiate(ex.head, t)
    end
end

f = :(2*w^2*u*v^2)
@show differentiate(f, :u)
differentiate(f, :u) = :(0 * w ^ 2 * u * v ^ 2 + 2 * 0 * u * v ^ 2 + 2 * w ^ 2 * 1 * v ^ 2 + 2 * w ^ 2 * u * 0)

Voimme sieventää lauseketta rakentelemalla funktion simplify, joka käy Expr-tyypin läpi sekä poistaa sieltä asioita sopivien ehtojen täyttyessä.

Luku ja muuttuja eivät sievene

function simplify(a::Int64)
    return a
end

@show simplify(1)
simplify(1) = 1
function simplify(s::Symbol)
    return s
end

@show simplify(:x)
simplify(:x) = :x

Kertolaskuissa \(f \cdot 0 = 0\), \(f \cdot 1 \cdot g = f \cdot g\), ja \(f \cdot 1 = f\):

function simplify(::Type{Val{:*}}, ex::Expr)
    args = simplify.(ex.args[2:end])
    0 in args && return 0
    filter!(k -> !(isa(k, Number) && k == 1), args)
    length(args) == 1 && return first(args)
    return Expr(:call, :*, args...)
end

f = :(a * 0)
g = :(a * 1)
@show simplify(Val{:*}, f)
@show simplify(Val{:*}, g)
simplify(Val{:*}, f) = 0
simplify(Val{:*}, g) = :a

Yhteenlaskuissa, nollatermit voidaan poistaa sekä jos yhteenlaskettavia on vain yksi määrä, ei ole mitään yhteenlaskettavaa: \(a + 0 + b = a + b\) ja \(a + 0 = a\).

function simplify(::Type{Val{:+}}, ex::Expr)
    args = simplify.(ex.args[2:end])
    filter!(k -> !isa(k, Number) || k != 0, args)
    length(args) == 1 && return first(args)
    return Expr(:call, :+, args...)
end

f = :(a + b + 0)
g = :(a + 0)
@show simplify(Val{:+}, f)
@show simplify(Val{:+}, g)
simplify(Val{:+}, f) = :(a + b)
simplify(Val{:+}, g) = :a

Sama vähennyslaskuissa

function simplify(::Type{Val{:-}}, ex::Expr)
    args = simplify.(ex.args[2:end])
    filter!(k -> !isa(k, Number) || k != 0, args)
    length(args) == 1 && return first(args)
    return Expr(:call, :-, args...)
end

Potenssilaskussa ^(a, b) voidaan yrittää sieventää molempia puolia

function simplify(::Type{Val{:^}}, ex::Expr)
    args = simplify.(ex.args[2:end])
    return Expr(:call, :^, args...)
end

f = :((2+0)^(1*a))
@show simplify(Val{:^}, f)
simplify(Val{:^}, f) = :(2 ^ a)

Lopuksi jälleen pääfunktio, jolla lausekkeen sieventäminen aloitetaan

function simplify(ex::Expr)
    return ex.head == :call ? simplify(Val{ex.args[1]}, ex) : ex
end

f = :(2*w^2*u*v^2)
@show simplify(differentiate(f, :u))
simplify(differentiate(f, :u)) = :(2 * w ^ 2 * v ^ 2)

Tarvitsemme vielä lisäksi jonkinlaisen subs-komennon, jolla symbolin tilalle voidaan sijoittaa numero:

function subs(a::Int, ::Pair{Symbol, Int}...)
    return a
end

function subs(s::Symbol, args::Pair{Symbol, Int}...)
    for (k, v) in args
        s == k && return v
    end
    return s
end

function subs(ex::Expr, args::Pair{Symbol, Int}...)
    ne = copy(ex)
    for i=2:length(ne.args)
        ne.args[i] = subs(ne.args[i], args...)
    end
    return ne
end

d = simplify(differentiate(:(2*w^2*u*v^2), :u))
@show d
@show subs(d, :w => 3)
d = :(2 * w ^ 2 * v ^ 2)
subs(d, :w => 3) = :(2 * 3 ^ 2 * v ^ 2)

Loput derivointisäännöt voitaisiin kirjoittaa vastaavanlaisilla säännöillä. Nyt derivaattorimme kuitenkin implementoi kaiken tarvittavan FEMBasis.jl:n tarpeisiin. Mikäli jotakin puuttuu, tulee vähintäänkin virheilmoitus, esimerkiksi

f = :(1 + x)
@show differentiate(f, :x)
ERROR: LoadError: MethodError: no method matching differentiate(::Type{Val{:+}}, ::Expr, ::Symbol)
f = :(sin(x))
@show differentiate(f, :x)
ERROR: LoadError: MethodError: no method matching differentiate(::Type{Val{:sin}}, ::Expr, ::Symbol)

Sitten se mielenkiintoinen kysymys. Nyt kun kehitelty oma symbolinen derivointi, niin kuinka sitä voisi käyttää tehokkaasti? Ja nyt makrot esittävät hyödyllisyytensä.

Saamme kyllä ajettua koodia ilmankin, rakentamalla korkeamman asteen funktion, joka palauttaa funktion derivaatan. Tämä olisi se todennäköinen ratkaisu, mikäli käyttäisi Pythonia ja SymPyä.

function diff(f, x)
    function diff_(u, v, w)
        df = simplify(differentiate(f, x))
        res = subs(df, :u => u, :v => v, :w => w)
        return eval(res)
    end
end

df1 = diff(:(2*w^2*u*v^2), :u)
@show df1(1, 2, 3)
df1(1, 2, 3) = 72

Koodi toimii, mutta joka kerta, kun funktiota kutsutaan, joudutaan derivoimaan funktio uudelleen. Jos laitetaan "käsin laskettu" derivaatta ja yllä mainittu funktio testipenkkiin, niin tulos on aika karu.

df3 = (u, v, w) -> 2 * w^2 * v^2

using BenchmarkTools
@btime df1(1, 2, 3)
@btime df3(1, 2, 3)
231.134 μs (600 allocations: 33.77 KiB)
 12.732 ns (0 allocations: 0 bytes)

Nopeusero on karkeasti laskettuna n. 20000-kertainen. On kyllä olemassa eri keinoja, joilla tehokkuutta voidaan parantaa. Esimerkiksi derivoinnin voi ottaa ulos sisemmästä funktiosta joka parantaa tilannetta hieman. Ratkaisu on joka tapauksessa kohtalaisen monimutkainen eikä lähelläkään sitä nopeutta, mitä se voisi olla mikäli funktion derivaatta olisi suoraan naputeltu lambda-lausekkeeseen.

Jos tiedetään jo ennen ohjelman suorittamista, että tarvitsemme kyseisen polynomin derivaatan arvoja, parasta olisi silloin laskea derivaatat jo ohjelman esikäsittelyvaiheessa, jolloin nopeat derivaatan laskennat onnistuvat suoraviivaisesti ohjelman ajon aikana. Makroilla voidaan saavuttaa ratkaisu, jossa ei ole mitään ylimääräistä laskennallista monimutkaisuutta siihen nähden, että derivaattafunktio olisi käsin kirjoitettu.

Koska makro kirjoitetaan ennen pääohjelman suorittamista, voimme siirtää kaiken symbolisen laskennan esikäsittelyyn ja saada derivaattafunktion, jonka nopeus on yhtä nopea, kuin se olisi kirjoitettu manuaalisesti. Makron pitää siis aivan konkreettisesti, palauttaa (u, v, w) -> 2 * w^2 * v^2, kun sille annetaan argumentteina (:(2*w^2*u*v^2), :u). Makro tässä tapauksessa on, mikäli halutaan säilyttää syntaksi samankaltaisena kuin funktiokutsussa:

macro diff(f, x)
    df = simplify(differentiate(eval(f), eval(x)))
    return :((u, v, w) -> $df)
end

Tässä kaikki monimutkaisuus on siirretty esikäsittelyyn ja lopputuloksena palautetaan funktio, joka on identtinen sen kanssa mitä käsin kirjoitettaisiin. Asian voi varmentaa @macroexpand-komennolla.

@macroexpand @diff(:(2*w^2*u*v^2), :u)
:((var"#34#u", var"#35#v", var"#36#w")->begin
          2 * var"#36#w" ^ 2 * var"#35#v" ^ 2
      end)

Julia käyttää sisäisesti gensym()-funktiota uudelleennimeämään muuttujat siten, ettei yhteensopivuusongelmia muun ohjelmakoodin kanssa synny. Mutta uusia muuttujan nimiä hieman katselemalla näkee, että kyseinen makro kyllä palauttaa funktion (u, v, w) -> 2 * w^2 * v^2, kuten pitääkin.

Yhteenvetomaisesti eri tekniikoilla derivointi:

df1 =  diff(:(2*w^2*u*v^2), :u)
df2 = @diff(:(2*w^2*u*v^2), :u)
df3 = (u, v, w) -> 2 * w^2 * v^2

@btime df1(1, 2, 3)
@btime df2(1, 2, 3)
@btime df3(1, 2, 3)
  266.269 μs (651 allocations: 35.36 KiB)
  12.735 ns (0 allocations: 0 bytes)
  12.686 ns (0 allocations: 0 bytes)

Makrot ovat hieman funktioita omituisempia rakenteita. Itse pyrin niitä henkilökohtaisesti hieman jopa välttelemään. Jos makroissa tulee bugeja, niiden debuggaaminen on hankalampaa ja konsepti on muutenkin vähemmän tunnettu. Mutta niille kyllä löytyy omat käyttökohteensa. Tässä nähtiin eräs potentiaalinen käyttökohde. Oikeissa paikoissa käytettynä ne ovat varsin tehokas lisätyökalu. Viimeistään siinä vaiheessa, kun tuntuu siltä, että haluaisi kirjoittaa jonkinlaisen skriptin, joka kirjoittaa Julia-koodia, tietää että ratkaisu löytyy makrosta.