 

Overloading methods for a parametrized struct

Short Version

I have a struct like this:

@kwdef mutable struct Params{B, C}
    a::Float64 = 42
end

where B and C are symbols that determine which function is used, out of many functions that are overloaded for different combinations of symbols. Valid symbols would be, for example, :constant and :exponential. I want to write a method that is invoked when one of the symbols matches the corresponding type parameter of the struct, regardless of the value of the other one. For example:

p = Params{:constant, :gaussian}()
foo(p::Params{:constant, <:Symbol}) = 42
foo(p) # should return 42, but throws a MethodError

How can I do this?

Context

I am implementing a dynamic rule (to be used with DynamicalSystems.jl) in which one term is computed either from the PDF of the Gaussian distribution or from a reciprocal function, or is held constant (see the supplementary information as well as figure 3 of "Quantitative modeling of the terminal differentiation of B cells and mechanisms of lymphomagenesis"). The code below is a simplified model that shows what I want to do:

using UnPack
using DynamicalSystems
using Distributions

@kwdef mutable struct Params{B, C}
    mu_p    = 10e-6
    sigma_p = 9
    mu_b    = 2
    sigma_b = 100
    mu_r    = 0.1
    sigma_r = 2.6
    
    bcr_0   = 0.05
    cd_0    = 0.025
end

function germinal_center_regulation_rule(u, params, t)
    # Field names must match the struct (bcr_0, cd_0)
    @unpack mu_p, sigma_p, mu_b, sigma_b, mu_r, sigma_r, bcr_0, cd_0 = params
    p, b, r = u
    #######IMPORTANT PART############
    bcr = compute_bcr(u, params, t)
    cd40 = compute_cd40(u, params, t)
    #################################
    pdot = mu_p + sigma_p/b
    bdot = mu_p + sigma_p/b - bcr*b
    rdot = mu_r * sigma_r/r
    return SVector(pdot, bdot, rdot)
end

# Different methods for different parameterized structs.
# Arguments are positional because Julia does not dispatch on keyword arguments.
compute_bcr(u, params::Params{:constant, <:Symbol}, t) = 15
compute_bcr(u, params::Params{:gaussian, <:Symbol}, t) = pdf(Normal(), t)
compute_cd40(u, params::Params{<:Symbol, :reciprocal}, t) = params.bcr_0/u[2]
compute_cd40(u, params::Params{<:Symbol, :gaussian}, t) = pdf(Normal(), t)

# Example usage
p_constant_bcr_gaussian_cd40 = Params{:constant, :gaussian}()
u0 = [0.2, 5.0, 0.2]
mixed_ds = CoupledODEs(
  germinal_center_regulation_rule, u0, p_constant_bcr_gaussian_cd40)
total_time = 200
X, t = trajectory(mixed_ds, total_time)

p_gaussian_bcr_gaussian_cd40 = Params{:gaussian, :gaussian}()
gaussian_ds = CoupledODEs(
  germinal_center_regulation_rule, u0, p_gaussian_bcr_gaussian_cd40)
total_time = 200
X, t = trajectory(gaussian_ds, total_time)
Jared asked Mar 06 '26 18:03


1 Answer

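The issue is that T<:U means that T is a subtype of U. In your case, :gaussian is not a subtype of Symbol (it is not a type at all) but a value of type Symbol, i.e. :gaussian isa Symbol. You can fix this by leaving that parameter as an unconstrained type variable: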

foo(p::Params{:constant, S}) where {S} = 42
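A minimal, self-contained sketch of both the diagnosis and the fix, assuming the cut-down single-field Params from the question. It checks that Params{:constant, :gaussian} is not a subtype of Params{:constant, <:Symbol} but is a subtype of Params{:constant, S} where {S}, and shows the same free-variable idea for the second parameter. The names foo_short and bar are hypothetical, added only for illustration; Base.@kwdef is used so the sketch does not depend on @kwdef being exported in the Julia version at hand.

```julia
# Sketch: why Params{:constant, <:Symbol} does not match, and how to fix it.
# Base.@kwdef avoids relying on @kwdef being exported by Base.
Base.@kwdef mutable struct Params{B, C}
    a::Float64 = 42
end

p = Params{:constant, :gaussian}()

# :gaussian is a value of type Symbol, not a type, so a bound like <:Symbol
# never matches it; an unbounded type variable (where {S}) does.
@show :gaussian isa Symbol                            # true
@show typeof(p) <: Params{:constant, <:Symbol}        # false
@show typeof(p) <: (Params{:constant, S} where {S})   # true

# Fix: match on the first parameter, leave the second one free.
foo(p::Params{:constant, S}) where {S} = 42

# Equivalent shorthand (hypothetical name): omitted trailing parameters
# are left free, so Params{:constant} is Params{:constant, C} where C.
foo_short(p::Params{:constant}) = 42

# Matching on the second parameter only (hypothetical name): the first
# parameter has to be written out as a free type variable.
bar(p::Params{B, :gaussian}) where {B} = "gaussian term"

@show foo(p)         # 42
@show foo_short(p)   # 42
@show bar(p)         # "gaussian term"
```

The shorthand only works for trailing parameters; to fix the second parameter while leaving the first one free, the where form used by bar is needed.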

If you must ensure that S is a Symbol, it's best to do so in Params’s constructor.

julia> @kwdef mutable struct Params{B, C}
           a::Float64 = 42
           function Params{B, C}(a) where {B, C}
               if !(B isa Symbol && C isa Symbol)
                   error("Params generics must be Symbols")
               end
               return new{B, C}(a)
           end
       end

julia> Params{:a, :b}(2)
Params{:a, :b}(2.0)

julia> Params{:a, 1}(2)
ERROR: Params generics must be Symbols
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] Params{:a, 1}(a::Int64)
   @ Main ./REPL[49]:5
 [3] top-level scope
   @ REPL[51]:1
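Applied to the Context example from the question, a minimal sketch of the term selectors might look like this. It assumes the selectors take u, params and t positionally, since Julia chooses methods from positional arguments only and keyword arguments do not take part in dispatch. To stay self-contained it keeps only the two Params fields the selectors use, leaves out the DynamicalSystems.jl integration, and replaces pdf(Normal(), t) from Distributions.jl with a hand-written stdnormal_pdf, which is a hypothetical helper.

```julia
# Sketch: the question's term selectors, dispatching on one type
# parameter at a time. Positional arguments are used because Julia does
# not dispatch on keyword arguments.
Base.@kwdef mutable struct Params{B, C}
    bcr_0 = 0.05
    cd_0  = 0.025
end

# Hypothetical helper: the standard normal PDF written out by hand, standing
# in for pdf(Normal(), t) so the sketch runs without Distributions.jl.
stdnormal_pdf(t) = exp(-t^2 / 2) / sqrt(2π)

# BCR term: selected by the first parameter B, whatever C is.
compute_bcr(u, params::Params{:constant}, t) = 15
compute_bcr(u, params::Params{:gaussian}, t) = stdnormal_pdf(t)

# CD40 term: selected by the second parameter C, whatever B is.
compute_cd40(u, params::Params{B, :reciprocal}, t) where {B} = params.bcr_0 / u[2]
compute_cd40(u, params::Params{B, :gaussian}, t) where {B} = stdnormal_pdf(t)

u = [0.2, 5.0, 0.2]
p = Params{:constant, :reciprocal}()
@show compute_bcr(u, p, 0.0)    # 15
@show compute_cd40(u, p, 0.0)   # bcr_0 / u[2], i.e. 0.05 / 5.0
```

Inside the rule function the terms would then be obtained with positional calls such as bcr = compute_bcr(u, params, t). A combination with no matching method, e.g. Params{:exponential, :gaussian}, still raises a MethodError, which acts as a check that only supported combinations are used.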

BallpointBen answered Mar 08 '26 19:03