pipe
run_generation
run_generation (key:PRNGKey, pop:chewc.population.Population, h2:jaxtyping.Float[Array,'nTraits'], n_parents:int, n_crosses:int, use_pheno_selection:bool, select_top_parents:bool, ploidy:int, gen_map:jax.Array, recomb_param_v:float, traits:chewc.trait.TraitCollection)
*The complete, JIT-compiled function for a single generation step.
This function combines selection, crossing, and phenotyping in a single computational graph, which XLA can optimize and fuse into efficient GPU/TPU kernels.
Args:
key: A JAX random key for the entire generation.
pop: The starting population for the generation.
h2: Heritability used to phenotype the next generation.
n_parents: Number of parents to select (static).
n_crosses: Number of crosses to make (static).
use_pheno_selection: Select on phenotype (True) or breeding value (False) (static).
select_top_parents: Select the top (True) or bottom (False) parents (static).
ploidy: Ploidy of the individuals (static).
gen_map: The genetic map from SimParam.
recomb_param_v: The recombination interference parameter from SimParam.
traits: The TraitCollection from SimParam.
Returns:
The new Population object for the next generation, with all values computed.*
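The following is a minimal sketch of what fusing selection, crossing, and phenotyping into one jitted function can look like. It does not use chewc: `toy_generation` is a hypothetical function that works on raw arrays and assumes a single additive trait, diploid 0/1 haplotypes, selection on true genetic value, and free recombination (it ignores anything like `gen_map` or `recomb_param_v`). The count arguments are static so that every array shape is known at compile time.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_parents", "n_crosses", "select_top"))
def toy_generation(key, haps, effects, h2, n_parents, n_crosses, select_top=True):
    """Hypothetical generation step: select -> cross -> phenotype in one jitted graph.

    haps:    (n_ind, 2, n_loci) array of 0/1 alleles (diploid).
    effects: (n_loci,) additive allele effects for a single trait.
    h2:      scalar heritability in (0, 1].
    """
    k_mate, k_gam, k_env = jax.random.split(key, 3)

    # 1. Selection on true genetic value (allele dosage x effect, summed over loci).
    gv = (haps.sum(axis=1) * effects).sum(axis=1)
    score = gv if select_top else -gv              # static bool: plain Python branch
    _, parents = jax.lax.top_k(score, n_parents)   # static n_parents -> fixed shape

    # 2. Crossing: random pairs drawn with replacement (selfing is possible).
    pairs = jax.random.choice(k_mate, parents, shape=(n_crosses, 2))
    parent_haps = haps[pairs]                      # (n_crosses, 2 parents, 2 haps, n_loci)
    # Free recombination (assumption): each locus independently comes from either haplotype.
    pick = jax.random.bernoulli(k_gam, 0.5, shape=parent_haps[:, :, 0, :].shape)
    new_haps = jnp.where(pick, parent_haps[:, :, 1, :], parent_haps[:, :, 0, :])

    # 3. Phenotyping: error variance chosen so that var_g / (var_g + var_e) = h2.
    new_gv = (new_haps.sum(axis=1) * effects).sum(axis=1)
    var_e = jnp.var(new_gv) * (1.0 - h2) / h2
    pheno = new_gv + jnp.sqrt(var_e) * jax.random.normal(k_env, new_gv.shape)
    return new_haps, new_gv, pheno


# Usage on a random founder population of 50 individuals and 100 loci.
k_init, k_eff, k_gen = jax.random.split(jax.random.PRNGKey(0), 3)
haps = jax.random.bernoulli(k_init, 0.5, (50, 2, 100)).astype(jnp.int32)
effects = jax.random.normal(k_eff, (100,))
new_haps, gv, pheno = toy_generation(
    k_gen, haps, effects, 0.5, n_parents=10, n_crosses=40, select_top=True
)
```

Each progeny row in `new_haps` holds one gamete from each parent, so the output has the same `(n_ind, 2, n_loci)` layout as the input and can be fed straight into the next call.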
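Argument | Type | Details
---|---|---
key | PRNGKey | JAX random key for the entire generation
pop | Population | Starting population for the generation
h2 | Float[Array, 'nTraits'] | Heritability for phenotyping the next generation
n_parents | int | Number of parents to select (static)
n_crosses | int | Number of crosses to make (static)
use_pheno_selection | bool | Select on phenotype (True) or breeding value (False) (static)
select_top_parents | bool | Select top (True) or bottom (False) parents (static)
ploidy | int | Ploidy of the individuals (static)
gen_map | Array | Genetic map, passed directly from SimParam
recomb_param_v | float | Recombination interference parameter from SimParam
traits | TraitCollection | TraitCollection from SimParam
Returns | Population | New population for the next generation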
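Static arguments are baked into the compiled program, so every new combination of static values (or new input shapes) triggers a retrace and recompile, while traced arguments such as keys, `h2`, and arrays can change freely between calls. The sketch below illustrates this with a hypothetical `pick_parents` function and splits one random key per generation from a single root key. The trace counter relies on the fact that Python side effects inside a jitted function run only while JAX is tracing it.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces (and then compiles) the function


@partial(jax.jit, static_argnames=("n_parents", "select_top"))
def pick_parents(values, n_parents, select_top):
    """Hypothetical: indices of the n_parents highest (or lowest) values."""
    global trace_count
    trace_count += 1                               # Python side effect: trace time only
    score = values if select_top else -values      # static bool -> ordinary Python branch
    return jax.lax.top_k(score, n_parents)[1]


# One key per generation, derived from a single root key.
gen_keys = jax.random.split(jax.random.PRNGKey(42), 5)
for k in gen_keys:
    values = jax.random.normal(k, (100,))          # new data each generation, same shape
    idx = pick_parents(values, n_parents=10, select_top=True)
print(trace_count)  # 1: same shapes and static values reuse the compiled program

pick_parents(values, n_parents=20, select_top=True)
print(trace_count)  # 2: a new static value forces a retrace
```

Keeping the static arguments fixed across a breeding program therefore means the generation step compiles once and every later generation reuses it.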
update_pop_values
update_pop_values (key:PRNGKey, pop:chewc.population.Population, sp:chewc.sp.SimParam, h2:jaxtyping.Float[Array,'nTraits'])
*(JIT-compatible) Calculates genetic and phenotypic values for a population.
This function is JIT-compatible because it wraps the `set_pheno` function.
Call it after creating progeny so that their values are populated.*
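The docstring does not say how `h2` is turned into environmental noise. One common convention for a purely additive trait sets the per-trait error variance to `var_e = var_g * (1 - h2) / h2`, so that `var_g / (var_g + var_e) = h2`. The sketch below assumes that convention; `toy_set_pheno` is hypothetical and operates on a raw `(n_ind, n_traits)` array of genetic values rather than a chewc Population.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_set_pheno(key, gv, h2):
    """Hypothetical phenotyping step.

    gv: (n_ind, n_traits) genetic values; h2: (n_traits,) heritabilities in (0, 1].
    Assumes var_e = var_g * (1 - h2) / h2 per trait (purely additive convention).
    """
    var_g = jnp.var(gv, axis=0)                    # per-trait genetic variance
    var_e = var_g * (1.0 - h2) / h2                # error variance that yields heritability h2
    noise = jax.random.normal(key, gv.shape) * jnp.sqrt(var_e)  # broadcasts over traits
    return gv + noise


key_g, key_e = jax.random.split(jax.random.PRNGKey(0))
gv = jax.random.normal(key_g, (10_000, 2))
pheno = toy_set_pheno(key_e, gv, jnp.array([0.5, 0.25]))
```

Because only array operations are used, the function traces once and works for any number of traits; note that `h2 = 0` would divide by zero under this convention.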
select_ind
select_ind (pop:chewc.population.Population, n_ind:int, use_pheno:bool=True, select_top:bool=True)
*(JIT-compatible) Selects the indices of the top or bottom individuals.
This function is fully JIT-compatible: branching uses `jax.lax.cond`
instead of Python control flow. Rather than returning a sliced Population
object, it returns the indices of the selected individuals. This is a more
flexible, functional approach that lets the caller decide how to use
the indices.
Args:
pop: The population to select from.
n_ind: The number of individuals to select (static argument).
use_pheno: If True, select on phenotype; if False, select on breeding value (bv)
(static argument).
select_top: If True, select the highest values; otherwise, the lowest (static argument).
Returns:
A 1D array of integer indices for the selected individuals.*
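A minimal sketch of index-based selection, assuming a single trait stored as 1D `pheno` and `bv` arrays; `toy_select_ind` is hypothetical. Here `use_pheno` and `select_top` are passed as traced booleans, which is the case where `jax.lax.cond` is required (if they were static, an ordinary Python `if` would compile just as well). Only `n_ind` must be static, because `jax.lax.top_k` needs a fixed `k`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_ind",))
def toy_select_ind(pheno, bv, n_ind, use_pheno=True, select_top=True):
    """Hypothetical: return indices of the n_ind best (or worst) individuals."""
    # Choose the selection criterion; both branches return the same shape and dtype.
    values = jax.lax.cond(use_pheno, lambda p, b: p, lambda p, b: b, pheno, bv)
    # Negating turns "select lowest" into "select highest".
    values = jax.lax.cond(select_top, lambda v: v, lambda v: -v, values)
    # top_k needs a static k, which is why n_ind is a static argument.
    _, idx = jax.lax.top_k(values, n_ind)
    return idx


# Usage: the caller decides what to do with the indices, e.g. gather parent genotypes.
pheno = jnp.array([3.0, 1.0, 4.0, 1.5, 9.0, 2.0])
bv = jnp.array([0.0, 5.0, 1.0, 7.0, 2.0, 3.0])
geno = jnp.arange(6 * 4).reshape(6, 4)   # toy genotype matrix, one row per individual
idx = toy_select_ind(pheno, bv, n_ind=2, use_pheno=True, select_top=True)
parents = geno[idx]                      # rows for individuals 4 and 2
```

Returning indices keeps the selection step independent of how a population is stored: the same `idx` can gather genotypes, pedigree IDs, or any other per-individual array.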