chewc

JAX breeding sim

Developer Guide

If you are new to using nbdev, here are some useful pointers to get you started.

Install chewc in Development mode

# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare

Usage

Installation

Install latest from the GitHub repository:

$ pip install git+https://github.com/cjGO/chewc.git

Documentation

Documentation is hosted on this GitHub repository’s pages.

Multi trait simulation example

This code block sets up the simulation parameters, creates a founder population in linkage equilibrium across loci, and simulates two correlated traits.

Then a single generation is simulated from the founder population and ABLUP and GBLUP models are fitted for the offspring.

Both the breeding simulation and the prediction models run on JAX.

How to use

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple
from scipy.sparse import coo_matrix

# JAX's iterative solver and sparse matrix format
from jax.scipy.sparse.linalg import cg
import jax.experimental.sparse as jsparse

# Assume the 'chewc' library is installed
from chewc.structs import (
    Population,
    Trait,
    GeneticMap,
    quick_haplo,
    add_trait
)
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair
from chewc.blup import *


if __name__ == "__main__":
    # End-to-end demo: simulate a founder population with two correlated
    # traits, breed one generation from phenotypically selected parents, then
    # compare pedigree-based (ABLUP) and genomic (GBLUP) breeding-value
    # estimates against the simulated true breeding values of the offspring.

    # --- Parameters ---
    N_FOUNDERS, N_SELECT, N_OFFSPRING = 200, 20, 200
    N_CHR, N_LOCI = 5, 1000
    SEED = 42
    N_TRAITS = 2
    h2_trait1, h2_trait2 = 0.5, 0.5
    genetic_corr = 0.5
    # Rows shown in the final table. Capped at the number of offspring because
    # JAX clamps out-of-bounds reads instead of raising, so an uncapped loop
    # would silently repeat the last row when N_OFFSPRING < 10.
    N_PREVIEW = min(10, N_OFFSPRING)
    # Built once and reused for both the founder and the offspring phenotypes.
    heritabilities = jnp.array([h2_trait1, h2_trait2])

    # --- Setup (Co)variance matrices ---
    # Genetic and residual variances of each trait sum to 1, so the genetic
    # variance equals the heritability and the residual variance is 1 - h2.
    var_g1, var_g2 = h2_trait1, h2_trait2
    cov_g12 = genetic_corr * jnp.sqrt(var_g1 * var_g2)
    G0 = jnp.array([[var_g1, cov_g12], [cov_g12, var_g2]])

    # Residuals are uncorrelated between traits (diagonal R0).
    var_e1, var_e2 = 1 - h2_trait1, 1 - h2_trait2
    R0 = jnp.diag(jnp.array([var_e1, var_e2]))

    G0_inv, R0_inv = jnp.linalg.inv(G0), jnp.linalg.inv(R0)

    # --- Population Simulation ---
    print("--- Step 1-4: Simulating population and multi-trait phenotypes ---")
    key = jax.random.PRNGKey(SEED)
    # One independent subkey per random step; `key` itself is rebound so it
    # can be split again later for the offspring phenotypes.
    key, pop_key, trait_key, pheno_key, mating_key, cross_key = jax.random.split(key, 6)

    founder_pop, genetic_map = quick_haplo(key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI)

    # Trait 2 is given a non-zero var_d while trait 1 has none; this is the
    # dominance difference the note printed with the results refers to.
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var_a=jnp.array([var_g1, var_g2]), var_d=jnp.array([0.0, 3.0]), sigma=G0
    )

    founder_phenotypes, founder_tbvs = calculate_phenotypes(
        key=pheno_key, population=founder_pop, trait=trait_architecture,
        heritability=heritabilities
    )

    # Parents are chosen on the trait-1 phenotype only (column 0).
    selected_parents = select_top_k(founder_pop, founder_phenotypes[:, 0], k=N_SELECT)
    pairings = random_mating(mating_key, n_parents=N_SELECT, n_crosses=N_OFFSPRING)

    # The first five arguments are mapped over their leading axis (one entry
    # per cross); genetic_map and the trailing literal are passed unchanged
    # to every call.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, N_OFFSPRING)
    # NOTE(review): the meaning of the literal 10 is defined by cross_pair's
    # signature, which is not visible here; confirm before changing it.
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, selected_parents.geno[pairings[:, 0]], selected_parents.geno[pairings[:, 1]],
        selected_parents.ibd[pairings[:, 0]], selected_parents.ibd[pairings[:, 1]],
        genetic_map, 10
    )

    # Offspring metadata rows: a new ID continuing after the founders, column 0
    # of each parent's meta row (presumably the parent ID; verify against
    # Population.meta), and a constant 1 (presumably the generation number).
    new_meta = jnp.stack([
        jnp.arange(N_OFFSPRING) + N_FOUNDERS,
        selected_parents.meta[pairings[:, 0], 0],
        selected_parents.meta[pairings[:, 1], 0],
        jnp.full((N_OFFSPRING,), 1),
    ], axis=-1)
    offspring_pop = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    key, offspring_pheno_key = jax.random.split(key)
    offspring_phenotypes, offspring_tbvs = calculate_phenotypes(
        key=offspring_pheno_key, population=offspring_pop, trait=trait_architecture,
        heritability=heritabilities
    )

    # Founders first, offspring second: the same ordering is used for the
    # pedigree and genotype stacks below so rows line up across inputs.
    all_phenotypes = jnp.concatenate([founder_phenotypes, offspring_phenotypes], axis=0)
    print("--- Population simulation complete ---")

    # --- ABLUP (Sparse, Iterative) ---
    print("\n--- Performing Multi-Trait ABLUP (Sparse Iterative) ---")
    full_pedigree = jnp.concatenate([founder_pop.meta, offspring_pop.meta], axis=0)
    remapped_ped_np = remap_pedigree(full_pedigree)

    A_inv_sparse = build_a_inverse_sparse(remapped_ped_np)
    ablup_ebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, A_inv_sparse, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("ABLUP calculation complete.")

    # --- GBLUP (Iterative) ---
    print("\n--- Performing Multi-Trait GBLUP (Iterative) ---")
    all_geno = jnp.concatenate([founder_pop.geno, offspring_pop.geno], axis=0)
    G_matrix = build_g_matrix(all_geno)
    # A small constant is added to the diagonal before inversion (presumably
    # to keep the genomic relationship matrix well conditioned).
    G_inv = jnp.linalg.inv(G_matrix + jnp.identity(G_matrix.shape[0]) * 1e-4)

    gblup_gebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, G_inv, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("GBLUP calculation complete.")

    # --- Compare Results ---
    print("\n--- Comparison of Results for Offspring ---")
    # Drop the founder rows; assumes the solver returns one row per individual
    # in input order (founders first) -- TODO confirm against chewc.blup.
    offspring_ablup = ablup_ebvs[N_FOUNDERS:]
    offspring_gblup = gblup_gebvs[N_FOUNDERS:]

    # Accuracy is the Pearson correlation between the simulated values in
    # offspring_tbvs and the estimates, taken per trait.
    acc_ablup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_ablup[:, 0])[0, 1]
    acc_ablup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_ablup[:, 1])[0, 1]
    acc_gblup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_gblup[:, 0])[0, 1]
    acc_gblup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_gblup[:, 1])[0, 1]

    print(f"\nABLUP Accuracy -> Trait 1: {acc_ablup_t1:.4f}, Trait 2: {acc_ablup_t2:.4f}")
    print(f"GBLUP Accuracy -> Trait 1: {acc_gblup_t1:.4f}, Trait 2: {acc_gblup_t2:.4f}")
    print('Note: both traits have the same heritabilities; but trait 2 has strong dominance effects, lowering accuracy')

    print("\n{:<6} | {:>12} {:>12} | {:>12} {:>12} | {:>12} {:>12}".format(
        "ID", "TBV T1", "TBV T2", "ABLUP T1", "ABLUP T2", "GBLUP T1", "GBLUP T2"))
    print("-" * 88)
    # Side-by-side preview of the first few offspring.
    for i in range(N_PREVIEW):
        print("{:<6} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f}".format(
            int(offspring_pop.meta[i, 0]),
            offspring_tbvs[i, 0], offspring_tbvs[i, 1],
            offspring_ablup[i, 0], offspring_ablup[i, 1],
            offspring_gblup[i, 0], offspring_gblup[i, 1]
        ))
--- Step 1-4: Simulating population and multi-trait phenotypes ---
--- Population simulation complete ---

--- Performing Multi-Trait ABLUP (Sparse Iterative) ---
[DEBUG] Sparse A_inv successfully created with 1406 non-zero elements.
[DEBUG] First 5 data points of A_inv: [1. 1. 1. 1. 1.]
ABLUP calculation complete.

--- Performing Multi-Trait GBLUP (Iterative) ---
GBLUP calculation complete.

--- Comparison of Results for Offspring ---

ABLUP Accuracy -> Trait 1: 0.7611, Trait 2: 0.5660
GBLUP Accuracy -> Trait 1: 0.8045, Trait 2: 0.5809
Note: both traits have the same heritabilities; but trait 2 has strong dominance effects, lowering accuracy

ID     |       TBV T1       TBV T2 |     ABLUP T1     ABLUP T2 |     GBLUP T1     GBLUP T2
----------------------------------------------------------------------------------------
200    |        0.847        1.461 |        0.322        1.116 |       -0.352        0.364
201    |        1.682        2.124 |        0.722        0.735 |        0.215        0.453
202    |        2.488        1.058 |        1.688        1.321 |        1.163        1.017
203    |        2.594        1.706 |        0.780       -0.772 |        0.553       -1.136
204    |        2.269        0.813 |        1.850        1.366 |        1.186        1.020
205    |        2.632        1.291 |        1.492        1.700 |        0.939        0.960
206    |        2.682        2.317 |        1.919        2.742 |        1.315        2.362
207    |        3.482        2.645 |        2.379        3.829 |        1.891        3.462
208    |        1.728        1.467 |        0.913        0.271 |        0.322       -0.260
209    |        2.602        2.125 |        0.888        0.409 |        0.511        0.211