import jax
import jax.numpy as jnp
import numpy as np

# Import core chewc components
from chewc.structs import add_trait
from chewc.burnin import run_burnin

# UPDATED: Import the pipeline runner from chewc.pipe
from chewc.pipe import run_simulation_cycles

# --- 1. Experiment Parameters ---
N_POP = 150             # Constant population size
N_ENVIRONMENTS = 1      # Single location, i.e. a single trait
N_CHR = 5               # Number of chromosomes (run_burnin n_chr)
N_LOCI = 1000           # Passed to run_burnin as n_loci
N_QTL = 50              # QTL *per chromosome* (add_trait n_qtl_per_chr)
SEED = 123              # Root seed for jax.random.PRNGKey

# Simulation Settings
BURN_IN_GENS = 50       # Generations run by run_burnin before selection
MAX_CROSSOVERS = 10     # Shared by burn-in and selection (max_crossovers)

# Selection Settings
N_SELECTION_GENS = 50   # Number of selection cycles to run
N_SELECT = 20           # Select top 20 parents
# Derived from N_POP (rather than a second literal 150) so the population
# size really stays constant if N_POP is ever changed.
N_OFFSPRING = N_POP     # Next gen size

# --- 2. Burn-in (Initialize & Establish LD) ---
print(
    "--- Setting up Single-Trait Experiment ---",
    f"Population Size: {N_POP}",
    f"Environments: {N_ENVIRONMENTS} (Single Trait)",
    f"\n[Phase 1] Running {BURN_IN_GENS} generations of burn-in...",
    sep="\n",
)

# Split the root key into three independent streams: `key` is retained,
# `burnin_key` seeds run_burnin and `trait_key` seeds add_trait below.
key, burnin_key, trait_key = jax.random.split(jax.random.PRNGKey(SEED), 3)

# One call both initialises the founder population and runs the burn-in
# generations; it yields the resulting state, the per-chromosome mean
# adjacent LD (r^2) and the genetic map reused by the selection phase.
stable_state, final_ld, genetic_map = run_burnin(
    key=burnin_key,
    n_gens=BURN_IN_GENS,
    n_pop=N_POP,
    n_chr=N_CHR,
    n_loci=N_LOCI,
    max_crossovers=MAX_CROSSOVERS,
)

print(
    f"Burn-in complete at Generation {stable_state.generation}",
    f"Mean Adjacent LD (r^2) per chromosome: {final_ld}",
    sep="\n",
)


# --- 3. Define Single Trait Architecture ---
print("\n[Phase 2] Defining Single Trait Architecture...")

# Identity matrix of size N_ENVIRONMENTS ([[1.0]] in the single-trait case),
# passed to add_trait as `sigma`. NOTE(review): presumably the
# between-environment genetic correlation matrix; confirm in chewc.structs.
genetic_correlation = jnp.eye(N_ENVIRONMENTS)

trait_arch = add_trait(
    key=trait_key,
    founder_pop=stable_state.population,
    n_qtl_per_chr=N_QTL,
    mean=jnp.zeros(N_ENVIRONMENTS),
    var_a=jnp.ones(N_ENVIRONMENTS),   # presumably additive variance = 1
    var_d=jnp.zeros(N_ENVIRONMENTS),  # presumably dominance variance = 0
    sigma=genetic_correlation,
)

# Heritability used by the selection phase, broadcast to one entry per
# environment so it always matches the other per-environment arrays above
# (a hard-coded [0.5] would silently mismatch if N_ENVIRONMENTS changed).
HERITABILITY = 0.5
HERITABILITIES = HERITABILITY * jnp.ones(N_ENVIRONMENTS)


# --- 4. Execute Selection Loop ---
print(f"\n--- Starting {N_SELECTION_GENS} Generations of Selection ---")

# Per-cycle settings for the selection pipeline, grouped so the call below
# reads as "what to evolve" plus "how to run each cycle".
selection_config = {
    "n_cycles": N_SELECTION_GENS,
    "n_select": N_SELECT,
    "n_offspring": N_OFFSPRING,
    "max_crossovers": MAX_CROSSOVERS,
}

# chewc.pipe.run_simulation_cycles owns the cycle loop (per the original
# note, it binds the settings and drives lax.scan itself). It returns the
# final state plus a per-cycle metrics history.
final_state, history = run_simulation_cycles(
    initial_state=stable_state,
    trait=trait_arch,
    genetic_map=genetic_map,
    heritabilities=HERITABILITIES,
    **selection_config,
)


# --- 5. Results Analysis ---
print("\nGeneration | Mean TBV (Genetic Gain) | Mean Phenotype")
print("---------------------------------------------------")
# Assumes `history` stacks to shape (n_cycles, >= 2) with column 0 = mean TBV
# and column 1 = mean phenotype per selection cycle (TODO confirm in chewc.pipe).
metrics_history = np.asarray(history)
# Row i is reported as the generation produced by selection cycle i + 1.
start_gen = int(stable_state.generation)

# Walk the rows actually returned instead of assuming N_SELECTION_GENS rows.
for cycle, (tbv, pheno) in enumerate(metrics_history[:, :2], start=1):
    gen = start_gen + cycle
    print(f"Gen {gen:<3}    | {tbv:<23.4f} | {pheno:.4f}")

if len(metrics_history) > 0:
    first_gen = start_gen + 1
    last_gen = start_gen + len(metrics_history)
    # Row 0 is already the first *selected* generation, so this difference
    # omits the response to the first selection cycle; the label states the
    # span explicitly. NOTE(review): for gain from the unselected base,
    # subtract the mean TBV of stable_state.population instead of row 0.
    total_gain = metrics_history[-1, 0] - metrics_history[0, 0]
    print(
        f"\nTotal Genetic Gain (Gen {first_gen} -> Gen {last_gen}): "
        f"{total_gain:.4f}"
    )
/home/glect/.local/lib/python3.10/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.
  warnings.warn("Unable to import Axes3D. This may be due to multiple versions of "
WARNING:2025-11-26 16:51:28,127:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
--- Setting up Single-Trait Experiment ---
Population Size: 150
Environments: 1 (Single Trait)

[Phase 1] Running 50 generations of burn-in...
Burn-in complete at Generation 50
Mean Adjacent LD (r^2) per chromosome: [0.15110835 0.14266446 0.14125235 0.14965521 0.1522844 ]

[Phase 2] Defining Single Trait Architecture...

--- Starting 50 Generations of Selection ---

Generation | Mean TBV (Genetic Gain) | Mean Phenotype
---------------------------------------------------
Gen 51     | 0.3050                  | 0.0893
Gen 52     | 1.5707                  | 1.3299
Gen 53     | 2.5418                  | 2.2419
Gen 54     | 3.4875                  | 3.1454
Gen 55     | 4.3080                  | 3.9606
Gen 56     | 5.1795                  | 4.8874
Gen 57     | 5.7102                  | 5.3484
Gen 58     | 6.5932                  | 6.2297
Gen 59     | 7.3113                  | 7.1080
Gen 60     | 7.8562                  | 7.5016
Gen 61     | 8.4342                  | 8.0798
Gen 62     | 8.9574                  | 8.6164
Gen 63     | 9.4117                  | 9.1064
Gen 64     | 9.9187                  | 9.5113
Gen 65     | 10.3476                 | 10.0543
Gen 66     | 10.7954                 | 10.5121
Gen 67     | 11.1176                 | 10.8407
Gen 68     | 11.4818                 | 11.2164
Gen 69     | 11.8435                 | 11.5521
Gen 70     | 12.1928                 | 11.9043
Gen 71     | 12.4904                 | 12.1992
Gen 72     | 12.7797                 | 12.4851
Gen 73     | 13.0009                 | 12.7128
Gen 74     | 13.1698                 | 12.8855
Gen 75     | 13.4130                 | 13.0821
Gen 76     | 13.6155                 | 13.3147
Gen 77     | 13.8012                 | 13.4902
Gen 78     | 13.9466                 | 13.6426
Gen 79     | 14.0811                 | 13.7798
Gen 80     | 14.2046                 | 13.8981
Gen 81     | 14.3261                 | 14.0233
Gen 82     | 14.4108                 | 14.1011
Gen 83     | 14.4789                 | 14.1806
Gen 84     | 14.5301                 | 14.2264
Gen 85     | 14.5796                 | 14.2757
Gen 86     | 14.6326                 | 14.3242
Gen 87     | 14.6717                 | 14.3613
Gen 88     | 14.7112                 | 14.4060
Gen 89     | 14.7555                 | 14.4467
Gen 90     | 14.7982                 | 14.4895
Gen 91     | 14.8162                 | 14.5124
Gen 92     | 14.8249                 | 14.5201
Gen 93     | 14.8360                 | 14.5312
Gen 94     | 14.8391                 | 14.5343
Gen 95     | 14.8411                 | 14.5362
Gen 96     | 14.8418                 | 14.5369
Gen 97     | 14.8421                 | 14.5372
Gen 98     | 14.8421                 | 14.5372
Gen 99     | 14.8421                 | 14.5372
Gen 100    | 14.8421                 | 14.5372

Total Genetic Gain: 14.5372