import msprime
import sys

# Record the msprime build and interpreter used so the notebook output
# documents the environment it was produced under.
print(
    f"msprime version being used: {msprime.__version__}",
    f"Python executable: {sys.executable}",
    sep="\n",
)
msprime version being used: 1.3.4
Python executable: /opt/hostedtoolcache/Python/3.10.18/x64/bin/python
import jax
import jax.numpy as jnp

# Assume the new meiosis module is available
from chewc.meiosis import produce_offspring 
from chewc.config import SimConfig
from chewc.structs import SimParam
from chewc.popgen import quick_haplo
from chewc.state import init_state_from_founders
from chewc.trait import add_trait_a, set_pheno_h2

# 1. Define static simulation configuration
config = SimConfig(
    n_chr=10,
    ploidy=2,
    max_pop_size=1000,
    n_loci_per_chr=100,
    n_generations=50,
    n_select=50,
    population_size=100,
)

# 2. Create and split the master random key: each stage below gets its own
#    subkey so the stages do not share a random stream.
key = jax.random.PRNGKey(42)
key, founder_key, trait_key, pheno_key, state_key, selection_key, meiosis_key = jax.random.split(key, 7)

# 3. Generate the founder population. The founder count is read from the
#    config rather than repeated as a literal, so the two cannot drift apart.
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    n_ind=config.population_size,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    max_pop_size=config.max_pop_size,
    ploidy=config.ploidy,
)

# 4. Define initial simulation parameters
sp = SimParam(
    gen_map=genetic_map,
    ploidy=config.ploidy,
    # Recombination parameters; the first entry is intended as the
    # interference parameter `v` of the Gamma crossover process.
    # NOTE(review): the meaning of the two trailing zeros is defined in
    # chewc.meiosis — confirm before changing them.
    recomb_params=(2.6, 0.0, 0.0),
)

# 5. Add an additive trait to the simulation parameters
#    (10 QTL per chromosome x 10 chromosomes = 100 QTL in total).
print("--- Defining Trait ---")
sp = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=sp,
    n_qtl_per_chr=10,
    mean=jnp.array([0.0]),
    var=jnp.array([1.0]),  # Target *genetic* variance
)
print(f"Trait added with {sp.traits.n_loci} QTLs.")

# 6. Set phenotypes using a target narrow-sense heritability (h2)
print("\n--- Setting Initial Phenotypes with h2 ---")
h2 = 0.4
founder_pop_with_pheno = set_pheno_h2(
    key=pheno_key, pop=founder_pop, sp=sp, h2=h2
)

# 7. Initialize the dynamic simulation state
print("\n--- Initializing Simulation State ---")
initial_state = init_state_from_founders(
    key=state_key,
    founder_pop=founder_pop_with_pheno,
    sp=sp,
    config=config,
)

print(f"Initial write position: {initial_state.write_pos}")
print(f"Next available ID: {initial_state.next_id}")
print(f"Founder population active: {jnp.sum(initial_state.is_active)}")

# 8. Verification: compare realized h2 (VarA / VarP over active individuals)
#    with the target. jnp.var pools every element it is given, so this
#    assumes a single trait.
print("\n--- Verification ---")
active_mask = initial_state.is_active
var_a = jnp.var(initial_state.bv[active_mask])
var_p = jnp.var(initial_state.pheno[active_mask])
realized_h2 = var_a / var_p

print(f"Target narrow-sense heritability (h2): {h2:.4f}")
print(f"Realized additive variance (VarA) in founders: {var_a:.4f}")
print(f"Realized phenotypic variance (VarP) in founders: {var_p:.4f}")
print(f"Realized narrow-sense heritability (h2) in founders: {realized_h2:.4f}")


# ==============================================================================
# --- NEW: Selection and Meiosis Workflow ---
# ==============================================================================

print("\n--- 9. Phenotypic Selection ---")

# In JAX, we operate on the full, fixed-size arrays and use masks.
# Inactive slots get a phenotype of -inf so they never outrank an active
# individual during selection.
pheno = initial_state.pheno[:, 0]  # Assuming one trait (column 0)
selectable_pheno = jnp.where(initial_state.is_active, pheno, -jnp.inf)

# Select the top 20% (20 individuals from the 100 active founders)
n_to_select = 20
# top_k would silently fill the selection with inactive (-inf) slots if fewer
# than `n_to_select` individuals were active, so check on the host first.
n_active = int(jnp.sum(initial_state.is_active))
if n_active < n_to_select:
    raise ValueError(
        f"Cannot select {n_to_select} individuals: only {n_active} are active."
    )
# Use lax.top_k, which is more efficient than argsort for selection.
_, top_indices = jax.lax.top_k(selectable_pheno, k=n_to_select)

print(f"Selected {len(top_indices)} individuals with top phenotypes.")
print(f"Indices of selected individuals: {top_indices}")


print("\n--- 10. Random Mating ---")

# Randomly shuffle the selected indices to create mating pairs
shuffled_indices = jax.random.permutation(selection_key, top_indices)

# Split the shuffled group into mothers and fathers. The father slice is
# bounded explicitly: with an odd selection size an open `[n_crosses:]` slice
# would hold one more father than there are mothers.
n_crosses = n_to_select // 2
mother_indices = shuffled_indices[:n_crosses]
father_indices = shuffled_indices[n_crosses:2 * n_crosses]

print(f"Created {n_crosses} random pairs for mating.")
print(f"Mother indices: {mother_indices}")
print(f"Father indices: {father_indices}")


print("\n--- 11. Produce Offspring via Meiosis ---")

# Collect every argument for the meiosis kernel in one place. Element j of
# `mother_indices` is crossed with element j of `father_indices`.
meiosis_kwargs = dict(
    key=meiosis_key,
    state=initial_state,
    sp=sp,
    config=config,
    mother_indices=mother_indices,
    father_indices=father_indices,
)
offspring_geno, offspring_ibd = produce_offspring(**meiosis_kwargs)

print("Meiosis complete.")
print(f"Shape of offspring geno array: {offspring_geno.shape}")
print(f"Shape of offspring IBD array: {offspring_ibd.shape}")

# The offspring arrays are not written back into `initial_state` here; storing
# them in the SimState is left to the per-generation step.
--- Defining Trait ---
Trait added with 100 QTLs.

--- Setting Initial Phenotypes with h2 ---

--- Initializing Simulation State ---
Initial write position: 100
Next available ID: 100
Founder population active: 100

--- Verification ---
Target narrow-sense heritability (h2): 0.4000
Realized additive variance (VarA) in founders: 1.0000
Realized phenotypic variance (VarP) in founders: 2.6278
Realized narrow-sense heritability (h2) in founders: 0.3805

--- 9. Phenotypic Selection ---
Selected 20 individuals with top phenotypes.
Indices of selected individuals: [93 68 29 31 87 34 26 92 71 88 77 33 59 75 30  0 18 52 39 35]

--- 10. Random Mating ---
Created 10 random pairs for mating.
Mother indices: [71 31 52 68 88 75 92 59  0 18]
Father indices: [39 29 35 87 93 77 33 34 26 30]

--- 11. Produce Offspring via Meiosis ---
Meiosis complete.
Shape of offspring geno array: (10, 10, 2, 100)
Shape of offspring IBD array: (10, 10, 2, 100)
import jax
import jax.numpy as jnp
from jax import vmap, lax
from functools import partial

# Assume the new meiosis module is available
from chewc.meiosis import produce_offspring
from chewc.config import SimConfig
from chewc.structs import SimParam
from chewc.state import SimState
from chewc.popgen import quick_haplo
from chewc.state import init_state_from_founders
from chewc.trait import add_trait_a, set_pheno_h2

# ==============================================================================
# --- 1. Host-Side Setup (Done Once) ---
# ==============================================================================

# Define static simulation configuration
config = SimConfig(
    n_chr=10,
    ploidy=2,
    max_pop_size=1000,
    n_loci_per_chr=100,
    n_generations=10,
    n_select=50,
    population_size=100,
    retention_generations=2,  # Keep parents and grandparents
)

# Create and split master random key
key = jax.random.PRNGKey(42)
key, founder_key, trait_key, pheno_key = jax.random.split(key, 4)

# Generate the founder population. The founder count is read from the config
# rather than repeated as a literal, so the founder cohort always has the
# same size as every cohort that generation_step writes.
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    n_ind=config.population_size,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    max_pop_size=config.max_pop_size,
    ploidy=config.ploidy,
)

# Define initial simulation parameters
# NOTE(review): recomb_params semantics are defined in chewc.meiosis
# (first entry described elsewhere as the Gamma-process `v`) — confirm.
sp = SimParam(
    gen_map=genetic_map,
    ploidy=config.ploidy,
    recomb_params=(2.6, 0.0, 0.0),
)

# Add an additive trait (10 QTL per chromosome x 10 chromosomes)
print("--- Defining Trait ---")
sp = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=sp,
    n_qtl_per_chr=10,
    mean=jnp.array([0.0]),
    var=jnp.array([1.0]),
)
print(f"Trait added with {sp.traits.n_loci} QTLs.")

# Set initial phenotypes. `h2` is also read as a global by generation_step.
print("\n--- Setting Initial Phenotypes with h2 ---")
h2 = 0.4
founder_pop_with_pheno = set_pheno_h2(
    key=pheno_key, pop=founder_pop, sp=sp, h2=h2
)

# ==============================================================================
# --- 2. JIT-Compiled Generation Kernel ---
# ==============================================================================

@partial(jax.jit, static_argnames=("config",))
def generation_step(state: SimState, sp: SimParam, config: SimConfig) -> SimState:
    """Advance one generation: select, mate, run meiosis and store the new cohort.

    The new cohort of ``config.population_size`` individuals is written into
    the fixed-size state arrays, which are treated as a ring buffer of length
    ``config.max_pop_size`` starting at ``state.write_pos``.

    Args:
        state: Current (un-batched) simulation state. Its ``key`` is split to
            drive selection, meiosis and phenotyping. The leftover key is
            stored in the returned state.
        sp: Simulation parameters (genetic map, traits, recombination).
        config: Static configuration. It must be hashable because it is a
            ``static_argnames`` argument of ``jax.jit``.

    Returns:
        A new ``SimState`` in which:
        - the offspring cohort has been written in;
        - the cohort from ``config.retention_generations`` generations ago has
          been deactivated (when that value is > 0);
        - breeding values and phenotypes have been recomputed;
        - ``write_pos``, ``gen_idx`` and ``next_id`` have been advanced.

    Raises:
        ValueError: at trace time, if ``config.n_select < 2`` (no mating pair
            can be formed) or if ``config.population_size`` exceeds
            ``config.max_pop_size``.

    NOTE(review): this function reads the module-level globals ``founder_pop``
    (used as a container template for ``set_pheno_h2``) and ``h2``. Under
    ``jax.jit`` their values are baked in at trace time, so rebinding them
    later does not affect already-compiled calls.
    """
    n_new = config.population_size
    n_pairs = config.n_select // 2
    # Both checks use static config values, so they run once, during tracing.
    if n_pairs == 0:
        raise ValueError("config.n_select must be at least 2 to form a mating pair")
    if n_new > config.max_pop_size:
        raise ValueError("config.population_size must not exceed config.max_pop_size")

    key, pheno_key, selection_key, meiosis_key = jax.random.split(state.key, 4)

    # 1. Phenotypic selection on trait 0. Inactive slots get -inf, so they
    #    are chosen only if fewer than `n_select` individuals are active.
    pheno = state.pheno[:, 0]
    selectable_pheno = jnp.where(state.is_active, pheno, -jnp.inf)
    _, top_indices = jax.lax.top_k(selectable_pheno, k=config.n_select)

    # 2. Random mating: shuffle the selected parents and pair the first half
    #    with the second half. The father slice has an explicit upper bound,
    #    because with an odd `n_select` an open slice would hold one more
    #    father than mothers.
    shuffled_indices = jax.random.permutation(selection_key, top_indices)
    mother_indices_base = shuffled_indices[:n_pairs]
    father_indices_base = shuffled_indices[n_pairs:2 * n_pairs]
    # Cycle through the pairs until exactly `n_new` crosses exist. This equals
    # `jnp.tile` when `n_new` is a multiple of `n_pairs`, but it never yields
    # fewer crosses than there are cohort slots.
    pair_of_cross = jnp.arange(n_new) % n_pairs
    mother_indices = mother_indices_base[pair_of_cross]
    father_indices = father_indices_base[pair_of_cross]

    # 3. Produce offspring (one per cross).
    offspring_geno, offspring_ibd = produce_offspring(
        key=meiosis_key,
        state=state,
        sp=sp,
        config=config,
        mother_indices=mother_indices,
        father_indices=father_indices,
    )

    # 4. Write the new cohort into the ring buffer. Row indices wrap modulo
    #    `max_pop_size`. `lax.dynamic_update_slice` would instead clamp the
    #    start index and silently overwrite the wrong rows whenever
    #    `write_pos + population_size` ran past the end of the buffer.
    write_pos = state.write_pos
    new_rows = (write_pos + jnp.arange(n_new)) % config.max_pop_size
    new_ids = state.next_id + jnp.arange(n_new)

    geno = state.geno.at[new_rows].set(offspring_geno)
    ibd = state.ibd.at[new_rows].set(offspring_ibd)
    id_col = state.id.at[new_rows].set(new_ids)
    mother = state.mother.at[new_rows].set(state.id[mother_indices])
    father = state.father.at[new_rows].set(state.id[father_indices])
    gen = state.gen.at[new_rows].set(state.gen_idx + 1)

    # Retire the cohort written `retention_generations` generations ago, then
    # activate the new cohort. Activating last guarantees the new cohort stays
    # active even when the buffer is too small and the two windows overlap.
    is_active = state.is_active
    if config.retention_generations > 0:
        retired_start = write_pos - config.retention_generations * n_new
        retired_rows = (retired_start + jnp.arange(n_new)) % config.max_pop_size
        is_active = is_active.at[retired_rows].set(False)
    is_active = is_active.at[new_rows].set(True)

    # 5. Recompute breeding values and phenotypes. `founder_pop` serves only
    #    as a container here: its genotypes and activity mask are replaced.
    # NOTE(review): set_pheno_h2 is applied to the whole buffer, so retained
    # parents presumably receive freshly drawn phenotypes every generation.
    # If it scales the error variance to the current genetic variance, VarE
    # also changes between generations. Confirm that both are intended.
    temp_pop = founder_pop.replace(geno=geno, is_active=is_active)
    pop_with_new_pheno = set_pheno_h2(
        key=pheno_key, pop=temp_pop, sp=sp, h2=h2
    )

    # 6. Return the updated state.
    return state.replace(
        key=key,
        geno=geno,
        ibd=ibd,
        pheno=pop_with_new_pheno.pheno,
        bv=pop_with_new_pheno.bv,
        is_active=is_active,
        id=id_col,
        mother=mother,
        father=father,
        gen=gen,
        write_pos=(write_pos + n_new) % config.max_pop_size,
        gen_idx=state.gen_idx + 1,
        next_id=state.next_id + n_new,
    )

# ==============================================================================
# --- 3. Running Replicates with vmap ---
# ==============================================================================

print("\n--- Initializing Replicates ---")
n_replicates = 3
key, *replicate_keys = jax.random.split(key, n_replicates + 1)

# Build one initial state per replicate key. Every replicate starts from the
# same phenotyped founders, `sp` and `config`; only the random key differs.
initial_states = []
for replicate_key in replicate_keys:
    initial_states.append(
        init_state_from_founders(
            key=replicate_key,
            founder_pop=founder_pop_with_pheno,
            sp=sp,
            config=config,
        )
    )

# Stack the per-replicate states leaf by leaf into one batched pytree whose
# arrays gain a leading replicate axis.
batched_initial_state = jax.tree_util.tree_map(
    lambda *replicate_leaves: jnp.stack(replicate_leaves), *initial_states
)

print(f"Created a batch of {n_replicates} initial states.")
print(f"Shape of batched geno: {batched_initial_state.geno.shape}")
print(f"Shape of batched keys: {batched_initial_state.key.shape}")

# Map over the leading replicate axis of the state; `sp` and `config` are
# shared by all replicates (in_axes=None).
vmapped_generation_step = vmap(generation_step, in_axes=(0, None, None))

def run_simulation(initial_carry, n_steps):
    """Advance the batched replicate state by ``n_steps`` generations.

    Uses ``lax.scan`` with the module-level ``vmapped_generation_step``,
    ``sp`` and ``config``. Per-step outputs are not collected; only the final
    carry (the batched state after the last step) is returned.
    """
    def _advance_one_generation(carry, _unused):
        return vmapped_generation_step(carry, sp, config), None

    last_carry, _ = lax.scan(
        _advance_one_generation, initial_carry, None, length=n_steps
    )
    return last_carry

print(f"\n--- Running {n_replicates} Replicates for {config.n_generations} Generations ---")
final_states = run_simulation(batched_initial_state, config.n_generations)
print("Simulation complete.")

# ==============================================================================
# --- 4. Analyze Results ---
# ==============================================================================
print("\n--- Analysis of Final States ---")

for i in range(n_replicates):
    # Pull replicate i out of every leaf of the batched state pytree.
    # (tree_map calls the lambda immediately, so capturing `i` is safe.)
    final_state_rep = jax.tree_util.tree_map(lambda leaf: leaf[i], final_states)
    active_mask = final_state_rep.is_active
    active_bv = final_state_rep.bv[active_mask]

    final_gen_idx = final_state_rep.gen_idx
    final_mean_bv = jnp.mean(active_bv)
    final_var_a = jnp.var(active_bv)

    print(f"\n--- Replicate {i+1} ---")
    print(f"  Final generation index: {final_gen_idx}")
    print(f"  Final mean breeding value: {final_mean_bv:.4f}")
    print(f"  Final additive variance (VarA): {final_var_a:.4f}")
--- Defining Trait ---
Trait added with 100 QTLs.

--- Setting Initial Phenotypes with h2 ---

--- Initializing Replicates ---
Created a batch of 3 initial states.
Shape of batched geno: (3, 1000, 10, 2, 100)
Shape of batched keys: (3, 2)

--- Running 3 Replicates for 10 Generations ---
Simulation complete.

--- Analysis of Final States ---

--- Replicate 1 ---
  Final generation index: 10
  Final mean breeding value: 6.4441
  Final additive variance (VarA): 0.5065

--- Replicate 2 ---
  Final generation index: 10
  Final mean breeding value: 6.6405
  Final additive variance (VarA): 0.5850

--- Replicate 3 ---
  Final generation index: 10
  Final mean breeding value: 6.3665
  Final additive variance (VarA): 0.4618