core
set_pheno
set_pheno (key:jax.random.PRNGKey, pop:chewc.population.Population, traits:chewc.trait.TraitCollection, ploidy:int, h2:jaxtyping.Float[Array,'nTraits'], cor_e:Optional[jaxtyping.Float[Array,'nTraits nTraits']]=None)
*Sets phenotypes for a population from its genetic values and a target heritability `h2` for each trait. The function is JIT-compiled with JAX.
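The docstring does not spell out the phenotype model. A convention used by breeding simulators such as AlphaSimR is to add environmental noise whose variance is chosen per trait so that var_G / (var_G + var_E) = h2, which gives var_E = var_G (1 - h2) / h2. The sketch below assumes that convention and works on a plain array of genetic values rather than chewc's `Population`; `sketch_set_pheno`, `gv`, and the reading of `cor_e` as an environmental correlation matrix are hypothetical, not chewc's actual internals.

```python
import jax
import jax.numpy as jnp


def sketch_set_pheno(key, gv, h2, cor_e=None):
    """Hypothetical sketch, not chewc's implementation.

    gv:    (n_ind, n_traits) genetic values
    h2:    (n_traits,) target heritabilities, each in (0, 1]
    cor_e: optional (n_traits, n_traits) environmental correlation matrix
    """
    n_ind, n_traits = gv.shape
    h2 = jnp.asarray(h2)
    # Genetic variance of each trait in the current population.
    var_g = jnp.var(gv, axis=0)
    # Solve h2 = var_g / (var_g + var_e) for the environmental variance.
    var_e = var_g * (1.0 - h2) / h2
    if cor_e is None:
        cor_e = jnp.eye(n_traits)
    # Rows of z have covariance L @ L.T = cor_e (L is the Cholesky factor).
    chol = jnp.linalg.cholesky(cor_e)
    z = jax.random.normal(key, (n_ind, n_traits)) @ chol.T
    # Scale each column to its environmental SD and add it to the genetic values.
    return gv + z * jnp.sqrt(var_e)
```

With `h2 = 1` the environmental variance is zero and the phenotypes equal the genetic values; smaller `h2` values add proportionally more noise.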
— JAX Implementation Notes —
This function is a JIT-compatible wrapper around the core logic in `_set_pheno_internal`. The `ploidy` argument is a plain Python integer that determines array shapes, so the JIT compiler must treat it as a “static” argument rather than trace it.
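To illustrate why a shape-determining integer must be static, the sketch below uses a hypothetical helper `zero_dosages` whose output shape depends on `ploidy`. Tracing `ploidy` fails because `jax.numpy` needs concrete integers for array shapes. JAX's built-in `static_argnames` option is shown here as one way to mark it static; it is not necessarily what chewc uses.

```python
import jax
import jax.numpy as jnp


def zero_dosages(x, ploidy):
    # Hypothetical helper: one row per entry of `x`, one column per
    # chromosome copy. The shape depends on `ploidy`, so it must be concrete.
    return jnp.zeros((x.shape[0], ploidy), dtype=jnp.int32)


x = jnp.arange(3.0)

# Without a static marker, jit traces `ploidy` as an abstract integer,
# and building an array shape from a tracer raises a TypeError.
try:
    jax.jit(zero_dosages)(x, 2)
    traced_ok = True
except TypeError:
    traced_ok = False

# Marking `ploidy` static makes its value part of the compilation cache key,
# so one program is compiled per distinct ploidy value.
static_fn = jax.jit(zero_dosages, static_argnames="ploidy")
out = static_fn(x, ploidy=2)
print(traced_ok, out.shape)  # False (3, 2)
```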
We achieve this with `functools.partial`, which creates a new function on the fly with `ploidy` fixed (“baked in”). This new function takes only JAX-traceable arguments, so it can be JIT-compiled and executed. As long as the jitted function is reused across calls, JAX re-compiles only when the value of `ploidy` changes.*
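The JAX documentation notes that `jax.jit` keys its compilation cache on the function being jitted, so wrapping a freshly built `functools.partial` in `jax.jit` on every call can re-trace each time. A minimal sketch of the partial-then-JIT pattern that avoids this caches one jitted partial per `ploidy` value with `functools.lru_cache`. Here `_set_pheno_internal` is a hypothetical stand-in, `set_pheno_sketch` and `_jitted_for_ploidy` are illustrative names, and whether chewc caches the partial this way is an assumption.

```python
import functools

import jax
import jax.numpy as jnp

TRACES = {"n": 0}  # counts how often JAX traces the core function


def _set_pheno_internal(key, gv, h2, ploidy):
    # Hypothetical stand-in for the real core logic. Python side effects
    # run only while JAX traces, so this counter counts (re)traces.
    TRACES["n"] += 1
    # `ploidy` fixes a shape here, which is why it must be static.
    noise = jax.random.normal(key, (gv.shape[0], ploidy)).mean(axis=1)
    return gv + noise * jnp.sqrt((1.0 - h2) / h2)


@functools.lru_cache(maxsize=None)
def _jitted_for_ploidy(ploidy):
    # Bake `ploidy` into a new function, then JIT-compile it. The cache
    # returns the same jitted object for a repeated ploidy value, so JAX
    # reuses its compiled program instead of re-tracing.
    return jax.jit(functools.partial(_set_pheno_internal, ploidy=ploidy))


def set_pheno_sketch(key, gv, h2, ploidy: int):
    # Hypothetical public wrapper: only JAX-traceable arguments reach jit.
    return _jitted_for_ploidy(ploidy)(key, gv, h2)
```

Calling `set_pheno_sketch` twice with `ploidy=2` traces once; switching to `ploidy=4` triggers one more trace. Passing `static_argnames="ploidy"` to `jax.jit` is an alternative that gives the same per-value compilation without a manual cache.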