import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple
from scipy.sparse import coo_matrix
# JAX's iterative solver and sparse matrix format
from jax.scipy.sparse.linalg import cg
import jax.experimental.sparse as jsparse
# Assume the 'chewc' library is installed
from chewc.structs import (
Population,
Trait,
GeneticMap,
quick_haplo,
add_trait
)
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair
from chewc.blup import *
if __name__ == "__main__":
    # Demo: simulate a founder population and one generation of offspring
    # for two correlated traits, then compare pedigree-based (ABLUP) and
    # genomic (GBLUP) multi-trait breeding-value predictions against the
    # simulated true breeding values of the offspring.

    # --- Parameters ---
    N_FOUNDERS, N_SELECT, N_OFFSPRING = 200, 20, 200
    N_CHR, N_LOCI = 5, 1000
    SEED = 42
    N_TRAITS = 2
    h2_trait1, h2_trait2 = 0.5, 0.5
    genetic_corr = 0.5
    # Small value added to the diagonal of G before it is inverted.
    G_RIDGE = 1e-4
    # Upper bound on the number of offspring rows shown in the final table.
    N_PREVIEW = 10

    # --- Setup (Co)variance matrices ---
    # Genetic and residual variances sum to 1 per trait, so each genetic
    # variance is numerically equal to the trait's heritability.
    var_g1, var_g2 = h2_trait1, h2_trait2
    cov_g12 = genetic_corr * jnp.sqrt(var_g1 * var_g2)
    G0 = jnp.array([[var_g1, cov_g12], [cov_g12, var_g2]])
    var_e1, var_e2 = 1 - h2_trait1, 1 - h2_trait2
    # Residuals are uncorrelated between traits (diagonal R0).
    R0 = jnp.diag(jnp.array([var_e1, var_e2]))
    G0_inv, R0_inv = jnp.linalg.inv(G0), jnp.linalg.inv(R0)
    # Same heritability vector is used for founders and offspring.
    heritabilities = jnp.array([h2_trait1, h2_trait2])

    # --- Population Simulation ---
    print("--- Step 1-4: Simulating population and multi-trait phenotypes ---")
    key = jax.random.PRNGKey(SEED)
    # One independent subkey per stochastic step; `key` itself is kept so it
    # can be split again for the offspring phenotypes.
    key, pop_key, trait_key, pheno_key, mating_key, cross_key = jax.random.split(key, 6)
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI
    )
    # Trait 1 gets var_d = 0.0 and trait 2 gets var_d = 3.0.
    trait_architecture = add_trait(
        key=trait_key,
        founder_pop=founder_pop,
        n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]),
        var_a=jnp.array([var_g1, var_g2]),
        var_d=jnp.array([0.0, 3.0]),
        sigma=G0,
    )
    founder_phenotypes, founder_tbvs = calculate_phenotypes(
        key=pheno_key,
        population=founder_pop,
        trait=trait_architecture,
        heritability=heritabilities,
    )

    # Parents are chosen on the first trait's phenotype only.
    selected_parents = select_top_k(founder_pop, founder_phenotypes[:, 0], k=N_SELECT)
    pairings = random_mating(mating_key, n_parents=N_SELECT, n_crosses=N_OFFSPRING)
    mother_idx, father_idx = pairings[:, 0], pairings[:, 1]

    # Map cross_pair over the leading axis of the per-cross inputs; the last
    # two arguments (genetic_map and the literal 10) are shared by all crosses.
    # NOTE(review): the meaning of the trailing 10 is defined by
    # chewc.cross.cross_pair and is not visible here - confirm before changing.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, N_OFFSPRING)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys,
        selected_parents.geno[mother_idx],
        selected_parents.geno[father_idx],
        selected_parents.ibd[mother_idx],
        selected_parents.ibd[father_idx],
        genetic_map,
        10,
    )

    # Offspring meta rows: [new id, column 0 of first parent's meta,
    # column 0 of second parent's meta, constant 1].
    # NOTE(review): presumably (id, mother id, father id, generation) with
    # founder ids running 0..N_FOUNDERS-1 - verify against chewc.structs.
    new_meta = jnp.stack(
        [
            jnp.arange(N_OFFSPRING) + N_FOUNDERS,
            selected_parents.meta[mother_idx, 0],
            selected_parents.meta[father_idx, 0],
            jnp.full((N_OFFSPRING,), 1),
        ],
        axis=-1,
    )
    offspring_pop = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    key, offspring_pheno_key = jax.random.split(key)
    offspring_phenotypes, offspring_tbvs = calculate_phenotypes(
        key=offspring_pheno_key,
        population=offspring_pop,
        trait=trait_architecture,
        heritability=heritabilities,
    )
    # Founders first, offspring second: the offspring slices taken below
    # rely on this ordering.
    all_phenotypes = jnp.concatenate([founder_phenotypes, offspring_phenotypes], axis=0)
    print("--- Population simulation complete ---")

    # --- ABLUP (Sparse, Iterative) ---
    print("\n--- Performing Multi-Trait ABLUP (Sparse Iterative) ---")
    full_pedigree = jnp.concatenate([founder_pop.meta, offspring_pop.meta], axis=0)
    remapped_ped_np = remap_pedigree(full_pedigree)
    A_inv_sparse = build_a_inverse_sparse(remapped_ped_np)
    ablup_ebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, A_inv_sparse, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("ABLUP calculation complete.")

    # --- GBLUP (Iterative) ---
    print("\n--- Performing Multi-Trait GBLUP (Iterative) ---")
    all_geno = jnp.concatenate([founder_pop.geno, offspring_pop.geno], axis=0)
    G_matrix = build_g_matrix(all_geno)
    # The diagonal is bumped by G_RIDGE before the dense inverse is taken.
    G_inv = jnp.linalg.inv(G_matrix + jnp.identity(G_matrix.shape[0]) * G_RIDGE)
    gblup_gebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, G_inv, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("GBLUP calculation complete.")

    # --- Compare Results ---
    print("\n--- Comparison of Results for Offspring ---")
    # Rows N_FOUNDERS onward belong to the offspring (see concatenation order).
    offspring_ablup = ablup_ebvs[N_FOUNDERS:]
    offspring_gblup = gblup_gebvs[N_FOUNDERS:]

    # Accuracy = Pearson correlation between true and estimated breeding
    # values, computed per trait.
    acc_ablup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_ablup[:, 0])[0, 1]
    acc_ablup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_ablup[:, 1])[0, 1]
    acc_gblup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_gblup[:, 0])[0, 1]
    acc_gblup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_gblup[:, 1])[0, 1]
    print(f"\nABLUP Accuracy -> Trait 1: {acc_ablup_t1:.4f}, Trait 2: {acc_ablup_t2:.4f}")
    print(f"GBLUP Accuracy -> Trait 1: {acc_gblup_t1:.4f}, Trait 2: {acc_gblup_t2:.4f}")
    print('Note: both traits have the same heritabilities; but trait 2 has strong dominance effects, lowering accuracy')

    print("\n{:<6} | {:>12} {:>12} | {:>12} {:>12} | {:>12} {:>12}".format(
        "ID", "TBV T1", "TBV T2", "ABLUP T1", "ABLUP T2", "GBLUP T1", "GBLUP T2"))
    print("-" * 88)
    # Cap the preview at the number of offspring: JAX clamps out-of-range
    # indices on reads instead of raising, so an uncapped loop would silently
    # repeat the last row when fewer than N_PREVIEW offspring exist.
    n_preview = min(N_PREVIEW, N_OFFSPRING)
    row_format = "{:<6} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f}"
    for i in range(n_preview):
        print(row_format.format(
            int(offspring_pop.meta[i, 0]),
            offspring_tbvs[i, 0], offspring_tbvs[i, 1],
            offspring_ablup[i, 0], offspring_ablup[i, 1],
            offspring_gblup[i, 0], offspring_gblup[i, 1]
        ))