If you are new to using nbdev, here are some useful pointers to get you started.
Install chewc in development mode
# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare
Documentation is hosted on this GitHub repository’s GitHub Pages site.
How to use
import jax
import jax.numpy as jnp
from typing import Callable, Union

# Import the necessary classes and functions from the chewc library
from chewc.sp import SimParam
from chewc.population import Population, quick_haplo
from chewc.trait import TraitCollection, add_trait_a
from chewc.phenotype import set_pheno
from chewc.cross import make_cross
from chewc.pipe import update_pop_values, select_and_cross

# --- 1. JAX setup ---
# One root PRNG key; each stochastic step below splits a fresh subkey off it,
# so the whole run is reproducible from the seed.
key = jax.random.PRNGKey(42)

# --- 2. Genome blueprint ---
# 3 chromosomes with 100 loci each, ploidy 2. Each chromosome's genetic map is
# an evenly spaced grid on [0, 1], and every centromere is set to 0.5.
n_chr, n_loci_per_chr, ploidy = 3, 100, 2
gen_map = jnp.array([jnp.linspace(0, 1, n_loci_per_chr) for _ in range(n_chr)])
centromeres = jnp.full(n_chr, 0.5)

# --- 3. Simulation parameters ---
SP = SimParam(gen_map=gen_map, centromere=centromeres, ploidy=ploidy)

# --- 4. Founder population ---
key, pop_key = jax.random.split(key)
founder_pop = quick_haplo(key=pop_key, sim_param=SP, n_ind=100, inbred=False)
SP = SP.replace(founderPop=founder_pop)

# --- 5. Single additive trait ---
# Float literals so the mean/variance arrays below are floating point rather
# than int32 (jnp.array([0]) would otherwise yield an integer array).
trait_mean = 0.0
trait_var = 1.0
trait_h2 = 0.1
key, trait_key = jax.random.split(key)
SP_with_trait = add_trait_a(
    key=trait_key,
    sim_param=SP,
    n_qtl_per_chr=100,
    mean=jnp.array([trait_mean]),
    var=jnp.array([trait_var]),
)

# --- 6. Initial phenotypes ---
key, pheno_key = jax.random.split(key)
h2 = jnp.array([trait_h2])
founder_pop_with_pheno = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=SP_with_trait.traits,
    ploidy=SP_with_trait.ploidy,
    h2=h2,
)

pop_burn_in = founder_pop_with_pheno
sp_burn_in = SP_with_trait

# --- 7. Selection parameters ---
n_parents_select = 5  # Total number of parents to select
n_progeny = 1000
burn_in_generations = 10

# --- 8. Burn-in selection over `burn_in_generations` generations (simplified loop) ---
print(f"\n--- Starting Burn-in Phenotypic Selection ({burn_in_generations} Generations) ---")
for gen in range(burn_in_generations):
    key, cross_key, update_key = jax.random.split(key, 3)

    # A single high-level call performs selection and crossing for one generation
    progeny_pop = select_and_cross(
        key=cross_key,
        pop=pop_burn_in,
        sp=sp_burn_in,
        n_parents=n_parents_select,
        n_crosses=n_progeny,
        use="pheno",  # Select based on phenotype
    )

    # Update genetic and phenotypic values for the new generation
    pop_burn_in = update_pop_values(update_key, progeny_pop, sp_burn_in, h2=h2)

    # Track progress
    mean_pheno = jnp.mean(pop_burn_in.pheno)
    print(f"Generation {gen + 1:2d}/{burn_in_generations} | Mean Phenotype: {mean_pheno:.4f}")

print("\n--- Burn-in Complete ---")
print(f"Final population state after {burn_in_generations} generations of selection:")
print(pop_burn_in)
WARNING:2025-07-19 12:53:52,707:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
--- Starting Burn-in Phenotypic Selection (10 Generations) ---
Generation 1/10 | Mean Phenotype: 0.4342
Generation 2/10 | Mean Phenotype: 2.2197
Generation 3/10 | Mean Phenotype: 4.3989
Generation 4/10 | Mean Phenotype: 4.8496
Generation 5/10 | Mean Phenotype: 5.4339
Generation 6/10 | Mean Phenotype: 5.7617
Generation 7/10 | Mean Phenotype: 6.2489
Generation 8/10 | Mean Phenotype: 6.4915
Generation 9/10 | Mean Phenotype: 6.6994
Generation 10/10 | Mean Phenotype: 6.8729
--- Burn-in Complete ---
Final population state after 10 generations of selection:
Population(nInd=1000, nTraits=1, has_ebv=No)