import jax
import jax.numpy as jnp
import numpy as np
# Import core chewc components
from chewc.structs import add_trait
from chewc.burnin import run_burnin
# Pipeline runner that executes the selection cycles (section 4 below)
from chewc.pipe import run_simulation_cycles
# --- 1. Experiment Parameters ---
N_POP = 150  # Constant population size (burn-in and every selected generation)
N_ENVIRONMENTS = 1  # Single Location/Trait
N_CHR = 5
N_LOCI = 1000
N_QTL = 50  # QTL *per chromosome* (passed to add_trait as n_qtl_per_chr)
SEED = 123
# Simulation Settings
BURN_IN_GENS = 50
MAX_CROSSOVERS = 10
# Selection Settings
N_SELECTION_GENS = 50
N_SELECT = 20  # Select top 20 parents
# Next gen size. Derived from N_POP rather than repeated as a literal so the
# population size really stays constant, as documented on N_POP above.
N_OFFSPRING = N_POP

# Fail fast on an impossible configuration: the parents are chosen from a
# population of N_POP individuals, so at least one and at most N_POP can be kept.
if not 0 < N_SELECT <= N_POP:
    raise ValueError(f"N_SELECT must be in 1..{N_POP}, got {N_SELECT}")
# --- 2. Burn-in (Initialize & Establish LD) ---
print("--- Setting up Single-Trait Experiment ---")
print(f"Population Size: {N_POP}")
print(f"Environments: {N_ENVIRONMENTS} (Single Trait)")
print(f"\n[Phase 1] Running {BURN_IN_GENS} generations of burn-in...")
key = jax.random.PRNGKey(SEED)
# Derive separate sub-keys for the burn-in and for sampling the trait
# architecture so the two random streams never share state. The refreshed
# `key` is not consumed again in this script.
key, burnin_key, trait_key = jax.random.split(key, 3)
# A single call initialises the founder population and runs the burn-in
# generations. It returns the post-burn-in state, the mean adjacent-locus LD
# (r^2) per chromosome, and the genetic map that the selection phase reuses.
stable_state, final_ld, genetic_map = run_burnin(
    key=burnin_key,
    n_gens=BURN_IN_GENS,
    n_pop=N_POP,
    n_chr=N_CHR,
    n_loci=N_LOCI,
    max_crossovers=MAX_CROSSOVERS,
)
print(f"Burn-in complete at Generation {stable_state.generation}")
print(f"Mean Adjacent LD (r^2) per chromosome: {final_ld}")
# --- 3. Define Single Trait Architecture ---
print("\n[Phase 2] Defining Single Trait Architecture...")
# N_ENVIRONMENTS x N_ENVIRONMENTS identity passed as `sigma`; for a single
# environment this is just [[1.0]]. The zero off-diagonals presumably mean
# uncorrelated environments if more are added -- confirm against add_trait.
genetic_correlation = jnp.eye(N_ENVIRONMENTS)
# One entry per environment: mean 0, var_a = 1, var_d = 0.
trait_arch = add_trait(
    key=trait_key,
    founder_pop=stable_state.population,
    n_qtl_per_chr=N_QTL,
    mean=jnp.zeros(N_ENVIRONMENTS),
    var_a=jnp.ones(N_ENVIRONMENTS),
    var_d=jnp.zeros(N_ENVIRONMENTS),
    sigma=genetic_correlation,
)
# Heritability of 0.5 for every environment. Sized from N_ENVIRONMENTS, like
# the trait parameters above, rather than as a hard-coded one-element array,
# so the lengths stay in step if N_ENVIRONMENTS is changed.
HERITABILITIES = jnp.full((N_ENVIRONMENTS,), 0.5)
# --- 4. Execute Selection Loop ---
print(f"\n--- Starting {N_SELECTION_GENS} Generations of Selection ---")
# Settings for chewc.pipe.run_simulation_cycles, gathered in one place. The
# call starts from the post-burn-in state and returns the final population
# state together with the per-cycle metrics consumed by the analysis below.
selection_kwargs = dict(
    trait=trait_arch,
    genetic_map=genetic_map,
    heritabilities=HERITABILITIES,
    n_cycles=N_SELECTION_GENS,
    n_select=N_SELECT,
    n_offspring=N_OFFSPRING,
    max_crossovers=MAX_CROSSOVERS,
)
final_state, history = run_simulation_cycles(
    initial_state=stable_state, **selection_kwargs
)
# --- 5. Results Analysis ---
print("\nGeneration | Mean TBV (Genetic Gain) | Mean Phenotype")
print("---------------------------------------------------")
# One row per recorded selection cycle; column 0 is the mean TBV and
# column 1 the mean phenotype.
metrics_history = np.array(history)
if metrics_history.size == 0:
    # Nothing recorded (e.g. N_SELECTION_GENS == 0): indexing [-1] would crash.
    print("\nTotal Genetic Gain: n/a (no selection cycles recorded)")
else:
    # Convert once, outside the loop; labels continue from the burn-in generation.
    start_gen = int(stable_state.generation)
    # Walk the rows actually returned instead of range(N_SELECTION_GENS), so a
    # history of a different length can neither raise IndexError nor be
    # silently truncated.
    for offset, (tbv, pheno) in enumerate(metrics_history[:, :2], start=1):
        gen = start_gen + offset
        print(f"Gen {gen:<3} | {tbv:<23.4f} | {pheno:.4f}")
    # Gain between the first and last recorded cycles. NOTE(review): if row 0
    # is already post-selection, the first cycle's gain is not included --
    # confirm against run_simulation_cycles.
    total_gain = metrics_history[-1, 0] - metrics_history[0, 0]
    print(f"\nTotal Genetic Gain: {total_gain:.4f}")

# ---------------------------------------------------------------------------
# Captured console output from a previous run (kept for reference):
# /home/glect/.local/lib/python3.10/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.
#     warnings.warn("Unable to import Axes3D. This may be due to multiple versions of "
# WARNING:2025-11-26 16:51:28,127:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
# --- Setting up Single-Trait Experiment ---
# Population Size: 150
# Environments: 1 (Single Trait)
# [Phase 1] Running 50 generations of burn-in...
# Burn-in complete at Generation 50
# Mean Adjacent LD (r^2) per chromosome: [0.15110835 0.14266446 0.14125235 0.14965521 0.1522844 ]
# [Phase 2] Defining Single Trait Architecture...
# --- Starting 50 Generations of Selection ---
# Generation | Mean TBV (Genetic Gain) | Mean Phenotype
# ---------------------------------------------------
# Gen 51 | 0.3050 | 0.0893
# Gen 52 | 1.5707 | 1.3299
# Gen 53 | 2.5418 | 2.2419
# Gen 54 | 3.4875 | 3.1454
# Gen 55 | 4.3080 | 3.9606
# Gen 56 | 5.1795 | 4.8874
# Gen 57 | 5.7102 | 5.3484
# Gen 58 | 6.5932 | 6.2297
# Gen 59 | 7.3113 | 7.1080
# Gen 60 | 7.8562 | 7.5016
# Gen 61 | 8.4342 | 8.0798
# Gen 62 | 8.9574 | 8.6164
# Gen 63 | 9.4117 | 9.1064
# Gen 64 | 9.9187 | 9.5113
# Gen 65 | 10.3476 | 10.0543
# Gen 66 | 10.7954 | 10.5121
# Gen 67 | 11.1176 | 10.8407
# Gen 68 | 11.4818 | 11.2164
# Gen 69 | 11.8435 | 11.5521
# Gen 70 | 12.1928 | 11.9043
# Gen 71 | 12.4904 | 12.1992
# Gen 72 | 12.7797 | 12.4851
# Gen 73 | 13.0009 | 12.7128
# Gen 74 | 13.1698 | 12.8855
# Gen 75 | 13.4130 | 13.0821
# Gen 76 | 13.6155 | 13.3147
# Gen 77 | 13.8012 | 13.4902
# Gen 78 | 13.9466 | 13.6426
# Gen 79 | 14.0811 | 13.7798
# Gen 80 | 14.2046 | 13.8981
# Gen 81 | 14.3261 | 14.0233
# Gen 82 | 14.4108 | 14.1011
# Gen 83 | 14.4789 | 14.1806
# Gen 84 | 14.5301 | 14.2264
# Gen 85 | 14.5796 | 14.2757
# Gen 86 | 14.6326 | 14.3242
# Gen 87 | 14.6717 | 14.3613
# Gen 88 | 14.7112 | 14.4060
# Gen 89 | 14.7555 | 14.4467
# Gen 90 | 14.7982 | 14.4895
# Gen 91 | 14.8162 | 14.5124
# Gen 92 | 14.8249 | 14.5201
# Gen 93 | 14.8360 | 14.5312
# Gen 94 | 14.8391 | 14.5343
# Gen 95 | 14.8411 | 14.5362
# Gen 96 | 14.8418 | 14.5369
# Gen 97 | 14.8421 | 14.5372
# Gen 98 | 14.8421 | 14.5372
# Gen 99 | 14.8421 | 14.5372
# Gen 100 | 14.8421 | 14.5372
# Total Genetic Gain: 14.5372