chewc

JAX breeding sim

Developer Guide

If you are new to nbdev, here are some useful pointers to get you started.

Install chewc in Development mode

# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare

Usage

Installation

Install latest from the GitHub repository:

$ pip install git+https://github.com/cjGO/chewc.git

Documentation

Documentation is hosted on this repository’s GitHub Pages site.

How to use

import jax
import jax.numpy as jnp

# Import the necessary classes and functions from chewc
from chewc.sp import SimParam
from chewc.population import Population, quick_haplo
from chewc.trait import TraitCollection, add_trait_a
from chewc.phenotype import set_pheno
from chewc.cross import make_cross
from chewc.pipe import update_pop_values, select_and_cross

# --- 1. JAX Setup ---
key = jax.random.PRNGKey(42)

# --- 2-6. Setup: genome blueprint, SimParam, founders, trait, phenotypes ---
# Define Genome Blueprint
n_chr, n_loci_per_chr, ploidy = 3, 100, 2
gen_map = jnp.array([jnp.linspace(0, 1, n_loci_per_chr) for _ in range(n_chr)])
centromeres = jnp.full(n_chr, 0.5)

# Instantiate SimParam
SP = SimParam(gen_map=gen_map, centromere=centromeres, ploidy=ploidy)

# Create Founder Population
key, pop_key = jax.random.split(key)
founder_pop = quick_haplo(key=pop_key, sim_param=SP, n_ind=100, inbred=False)
SP = SP.replace(founderPop=founder_pop)

# Add Single Additive Trait
trait_mean = 0
trait_var = 1
trait_h2 = 0.1

key, trait_key = jax.random.split(key)
SP_with_trait = add_trait_a(
    key=trait_key,
    sim_param=SP,
    n_qtl_per_chr=100,
    mean=jnp.array([trait_mean]),
    var=jnp.array([trait_var])
)

# Set Initial Phenotypes
key, pheno_key = jax.random.split(key)
h2 = jnp.array([trait_h2])
founder_pop_with_pheno = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=SP_with_trait.traits,
    ploidy=SP_with_trait.ploidy,
    h2=h2
)


# --- 7. Burn-in Setup ---
pop_burn_in = founder_pop_with_pheno
sp_burn_in = SP_with_trait

# Selection parameters
n_parents_select = 5  # Total number of parents to select
n_progeny = 1000
burn_in_generations = 10

# --- 8. Burn-in Selection for `burn_in_generations` Generations (Simplified Loop) ---
print(f"\n--- Starting Burn-in Phenotypic Selection ({burn_in_generations} Generations) ---")

for gen in range(burn_in_generations):
    key, cross_key, update_key = jax.random.split(key, 3)

    # **SINGLE, HIGH-LEVEL CALL** to handle a full generation
    progeny_pop = select_and_cross(
        key=cross_key,
        pop=pop_burn_in,
        sp=sp_burn_in,
        n_parents=n_parents_select,
        n_crosses=n_progeny,
        use="pheno" # Select based on phenotype
    )
    
    # Update genetic and phenotypic values for the new generation
    pop_burn_in = update_pop_values(update_key, progeny_pop, sp_burn_in, h2=h2)

    # Track Progress
    mean_pheno = jnp.mean(pop_burn_in.pheno)
    print(f"Generation {gen + 1:2d}/{burn_in_generations} | Mean Phenotype: {mean_pheno:.4f}")

print("\n--- Burn-in Complete ---")
print(f"Final population state after {burn_in_generations} generations of selection:")
print(pop_burn_in)

--- Starting Burn-in Phenotypic Selection (10 Generations) ---
Generation  1/10 | Mean Phenotype: 0.4342
Generation  2/10 | Mean Phenotype: 2.2197
Generation  3/10 | Mean Phenotype: 4.3989
Generation  4/10 | Mean Phenotype: 4.8496
Generation  5/10 | Mean Phenotype: 5.4339
Generation  6/10 | Mean Phenotype: 5.7617
Generation  7/10 | Mean Phenotype: 6.2489
Generation  8/10 | Mean Phenotype: 6.4915
Generation  9/10 | Mean Phenotype: 6.6994
Generation 10/10 | Mean Phenotype: 6.8729

--- Burn-in Complete ---
Final population state after 10 generations of selection:
Population(nInd=1000, nTraits=1, has_ebv=No)
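
To make the two ideas behind the loop above more concrete (phenotypes generated at a target heritability, then truncation selection on phenotype), here is a minimal pure-JAX sketch. It does not call chewc and is not how chewc implements `set_pheno` or `select_and_cross`; the helper names `simulate_pheno` and `select_top` are hypothetical, and the genetic values are random numbers rather than values derived from genotypes.

```python
import jax
import jax.numpy as jnp


def simulate_pheno(key, gv, h2):
    """Add environmental noise so that var(gv) / var(pheno) is roughly h2.

    Hypothetical helper for illustration only, not chewc's set_pheno.
    """
    var_g = jnp.var(gv)
    # Rearranged from h2 = var_g / (var_g + var_e)
    var_e = var_g * (1.0 - h2) / h2
    noise = jax.random.normal(key, gv.shape) * jnp.sqrt(var_e)
    return gv + noise


def select_top(pheno, n_parents):
    """Return indices of the n_parents individuals with the highest phenotype.

    Hypothetical helper for illustration only, not chewc's select_and_cross.
    """
    _, idx = jax.lax.top_k(pheno, n_parents)
    return idx


# Split the key explicitly for each random step, as in the example above
key = jax.random.PRNGKey(0)
key, gv_key, pheno_key = jax.random.split(key, 3)

gv = jax.random.normal(gv_key, (1000,))        # assumed genetic values
pheno = simulate_pheno(pheno_key, gv, h2=0.1)  # noisy phenotypes at h2 = 0.1
parents = select_top(pheno, n_parents=5)       # truncation selection

print("selected indices:", parents)
print("mean genetic value of selected:", jnp.mean(gv[parents]))
print("realized h2:", jnp.var(gv) / jnp.var(pheno))
```

With a low heritability such as 0.1, most of the phenotypic variance is noise, so the selected individuals are not necessarily those with the highest genetic values. Raising `h2` in the sketch makes phenotype a better proxy for genetic value.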