import msprime
import sys
# Report which msprime build and interpreter this notebook is running on.
print("msprime version being used: {}".format(msprime.__version__))
print("Python executable: {}".format(sys.executable))
msprime version being used: 1.3.4
Python executable: /opt/hostedtoolcache/Python/3.10.18/x64/bin/python
msprime version being used: 1.3.4
Python executable: /opt/hostedtoolcache/Python/3.10.18/x64/bin/python
import jax
import jax.numpy as jnp
# Assume the new meiosis module is available
from chewc.meiosis import produce_offspring
from chewc.config import SimConfig
from chewc.structs import SimParam
from chewc.popgen import quick_haplo
from chewc.state import init_state_from_founders
from chewc.trait import add_trait_a, set_pheno_h2
# 1. Static simulation configuration.
config = SimConfig(
    n_chr=10,
    n_loci_per_chr=100,
    ploidy=2,
    population_size=100,
    max_pop_size=1000,
    n_select=50,
    n_generations=50,
)

# 2. Master PRNG key, split into one independent stream per stage.
#    (Unpacking order determines which sub-key each stage receives.)
key = jax.random.PRNGKey(42)
(
    key,
    founder_key,
    trait_key,
    pheno_key,
    state_key,
    selection_key,
    meiosis_key,
) = jax.random.split(key, 7)

# 3. Founder haplotypes together with their genetic map.
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    n_ind=100,
    ploidy=config.ploidy,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    max_pop_size=config.max_pop_size,
)

# 4. Initial simulation parameters. `recomb_params` carries the default
#    recombination parameter `v` for the Gamma process (presumably the first
#    entry, 2.6 -- confirm against SimParam).
sp = SimParam(
    recomb_params=(2.6, 0.0, 0.0),
    ploidy=config.ploidy,
    gen_map=genetic_map,
)

# 5. Attach an additive trait; `var` is the target *genetic* variance.
print("--- Defining Trait ---")
sp = add_trait_a(
    key=trait_key,
    sim_param=sp,
    founder_pop=founder_pop,
    n_qtl_per_chr=10,
    mean=jnp.array([0.0]),
    var=jnp.array([1.0]),
)
print(f"Trait added with {sp.traits.n_loci} QTLs.")

# 6. Draw founder phenotypes so that the narrow-sense heritability targets h2.
print("\n--- Setting Initial Phenotypes with h2 ---")
h2 = 0.4
founder_pop_with_pheno = set_pheno_h2(key=pheno_key, pop=founder_pop, sp=sp, h2=h2)

# 7. Build the dynamic (fixed-size, masked) simulation state from the founders.
print("\n--- Initializing Simulation State ---")
initial_state = init_state_from_founders(
    key=state_key,
    founder_pop=founder_pop_with_pheno,
    sp=sp,
    config=config,
)
print(f"Initial write position: {initial_state.write_pos}")
print(f"Next available ID: {initial_state.next_id}")
print(f"Founder population active: {jnp.sum(initial_state.is_active)}")

# 8. Compare realized vs. target heritability over the active individuals only.
print("\n--- Verification ---")
active_mask = initial_state.is_active
var_a, var_p = (
    jnp.var(values[active_mask]) for values in (initial_state.bv, initial_state.pheno)
)
realized_h2 = var_a / var_p
print(
    "\n".join(
        [
            f"Target narrow-sense heritability (h2): {h2:.4f}",
            f"Realized additive variance (VarA) in founders: {var_a:.4f}",
            f"Realized phenotypic variance (VarP) in founders: {var_p:.4f}",
            f"Realized narrow-sense heritability (h2) in founders: {realized_h2:.4f}",
        ]
    )
)
# ==============================================================================
# --- NEW: Selection and Meiosis Workflow ---
# ==============================================================================
print("\n--- 9. Phenotypic Selection ---")
# JAX works on the full fixed-size arrays plus masks. Inactive slots get a
# phenotype of -inf so they rank below every active individual in top_k.
pheno = initial_state.pheno[:, 0]  # Assumes a single trait (column 0).
selectable_pheno = jnp.where(initial_state.is_active, pheno, -jnp.inf)
# Select the top 20 individuals (20% of the 100 active founders).
n_to_select = 20
n_active = int(jnp.sum(initial_state.is_active))
if n_to_select > n_active:
    # Otherwise top_k would silently pick inactive (-inf) slots as parents.
    raise ValueError(
        f"Cannot select {n_to_select} individuals: only {n_active} are active."
    )
# lax.top_k avoids a full sort, unlike argsort.
_, top_indices = jax.lax.top_k(selectable_pheno, k=n_to_select)
print(f"Selected {len(top_indices)} individuals with top phenotypes.")
print(f"Indices of selected individuals: {top_indices}")

print("\n--- 10. Random Mating ---")
# Shuffle the selected individuals, then pair the first half with the second.
shuffled_indices = jax.random.permutation(selection_key, top_indices)
n_crosses = n_to_select // 2
mother_indices = shuffled_indices[:n_crosses]
# Slice fathers to exactly `n_crosses` so both parent vectors have equal length
# even when `n_to_select` is odd (the leftover individual simply does not mate).
father_indices = shuffled_indices[n_crosses : 2 * n_crosses]
print(f"Created {n_crosses} random pairs for mating.")
print(f"Mother indices: {mother_indices}")
print(f"Father indices: {father_indices}")

print("\n--- 11. Produce Offspring via Meiosis ---")
# Top-level, JIT-compatible meiosis kernel: one offspring per
# (mother, father) entry. It presumably vmaps over the crosses -- see
# chewc.meiosis.
offspring_geno, offspring_ibd = produce_offspring(
    key=meiosis_key,
    state=initial_state,
    sp=sp,
    config=config,
    mother_indices=mother_indices,
    father_indices=father_indices,
)
print("Meiosis complete.")
print(f"Shape of offspring geno array: {offspring_geno.shape}")
print(f"Shape of offspring IBD array: {offspring_ibd.shape}")
# `offspring_geno` and `offspring_ibd` are ready to be written into the main
# `SimState` arrays in the next generation step.
--- Defining Trait ---
Trait added with 100 QTLs.
--- Setting Initial Phenotypes with h2 ---
--- Initializing Simulation State ---
Initial write position: 100
Next available ID: 100
Founder population active: 100
--- Verification ---
Target narrow-sense heritability (h2): 0.4000
Realized additive variance (VarA) in founders: 1.0000
Realized phenotypic variance (VarP) in founders: 2.6278
Realized narrow-sense heritability (h2) in founders: 0.3805
--- 9. Phenotypic Selection ---
Selected 20 individuals with top phenotypes.
Indices of selected individuals: [93 68 29 31 87 34 26 92 71 88 77 33 59 75 30 0 18 52 39 35]
--- 10. Random Mating ---
Created 10 random pairs for mating.
Mother indices: [71 31 52 68 88 75 92 59 0 18]
Father indices: [39 29 35 87 93 77 33 34 26 30]
--- 11. Produce Offspring via Meiosis ---
Meiosis complete.
Shape of offspring geno array: (10, 10, 2, 100)
Shape of offspring IBD array: (10, 10, 2, 100)
import jax
import jax.numpy as jnp
from jax import vmap, lax
from functools import partial
# Assume the new meiosis module is available
from chewc.meiosis import produce_offspring
from chewc.config import SimConfig
from chewc.structs import SimParam
from chewc.state import SimState
from chewc.popgen import quick_haplo
from chewc.state import init_state_from_founders
from chewc.trait import add_trait_a, set_pheno_h2
# ==============================================================================
# --- 1. Host-Side Setup (Done Once) ---
# ==============================================================================
# Static configuration for the multi-generation run.
config = SimConfig(
    n_chr=10,
    n_loci_per_chr=100,
    ploidy=2,
    population_size=100,
    max_pop_size=1000,
    n_select=50,
    n_generations=10,
    retention_generations=2,  # Keep parents and grandparents
)

# Fresh master key; one sub-key each for founders, trait and phenotypes.
key = jax.random.PRNGKey(42)
(key, founder_key, trait_key, pheno_key) = jax.random.split(key, 4)

# Founder haplotypes and their genetic map.
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    n_ind=100,
    ploidy=config.ploidy,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    max_pop_size=config.max_pop_size,
)

# Simulation parameters shared by every replicate.
sp = SimParam(
    recomb_params=(2.6, 0.0, 0.0),
    ploidy=config.ploidy,
    gen_map=genetic_map,
)

# One additive trait with unit target genetic variance.
print("--- Defining Trait ---")
sp = add_trait_a(
    key=trait_key,
    sim_param=sp,
    founder_pop=founder_pop,
    n_qtl_per_chr=10,
    mean=jnp.array([0.0]),
    var=jnp.array([1.0]),
)
print(f"Trait added with {sp.traits.n_loci} QTLs.")

# Founder phenotypes at the target heritability. Note: `founder_pop` and `h2`
# are also read as module-level globals inside `generation_step`.
print("\n--- Setting Initial Phenotypes with h2 ---")
h2 = 0.4
founder_pop_with_pheno = set_pheno_h2(key=pheno_key, pop=founder_pop, sp=sp, h2=h2)
# ==============================================================================
# --- 2. JIT-Compiled Generation Kernel ---
# ==============================================================================
@partial(jax.jit, static_argnames=("config",))
def generation_step(state: SimState, sp: SimParam, config: SimConfig) -> SimState:
    """Advance one replicate by one generation: select, mate, meiose, phenotype.

    The top ``config.n_select`` active individuals (by the first trait's
    phenotype) are shuffled into ``n_select // 2`` mating pairs. Exactly
    ``config.population_size`` offspring are produced, assigned to pairs
    round-robin. The new cohort is written into the fixed-size buffers of
    ``state`` starting at ``state.write_pos``, and the cohort written
    ``config.retention_generations`` cohorts earlier is deactivated. This
    function is pure and JIT-compatible.

    Args:
        state: Current simulation state of a single replicate.
        sp: Simulation parameters (genetic map, traits, recombination).
        config: Static configuration; passed as a static jit argument.

    Returns:
        A new ``SimState`` with the offspring written in, updated
        ``pheno``/``bv``/``is_active``/pedigree columns, a fresh RNG key, and
        ``write_pos``, ``gen_idx`` and ``next_id`` advanced.

    Raises:
        ValueError: At trace time, if ``config`` cannot yield a well-formed
            cohort. That is: ``n_select < 2``; ``max_pop_size`` is not a
            multiple of ``population_size``; or the retention window covers
            the whole buffer.

    Note:
        Phenotypes are recomputed with ``set_pheno_h2`` using the module-level
        ``founder_pop`` (as a template for all fields except ``geno`` and
        ``is_active``) and the module-level ``h2``. Both are captured as
        constants when this function is traced.
    """
    # --- Trace-time validation (config is static, so this costs nothing at
    # runtime and runs once per compilation). ---
    n_pairs = config.n_select // 2
    if n_pairs < 1:
        raise ValueError(
            f"config.n_select must be >= 2 to form a mating pair, got {config.n_select}"
        )
    # The write position advances in steps of population_size (mod
    # max_pop_size). If the buffer length is not a multiple of it,
    # lax.dynamic_update_slice clamps the start index near the end of the
    # buffer and silently overwrites other rows. (Also assumes the initial
    # write_pos is a multiple of population_size -- TODO confirm in
    # init_state_from_founders.)
    if config.max_pop_size % config.population_size != 0:
        raise ValueError(
            "config.max_pop_size must be a multiple of config.population_size "
            f"(got {config.max_pop_size} and {config.population_size})"
        )
    # If the retention window spans the whole buffer, the deactivation slice
    # wraps onto (or past) the cohort that was just written.
    if (
        config.retention_generations > 0
        and config.retention_generations * config.population_size
        >= config.max_pop_size
    ):
        raise ValueError(
            "config.retention_generations * config.population_size must be "
            "smaller than config.max_pop_size"
        )

    key, pheno_key, selection_key, meiosis_key = jax.random.split(state.key, 4)

    # 1. Phenotypic (truncation) selection on the first trait. Inactive slots
    #    get -inf so they rank below every active individual.
    pheno = state.pheno[:, 0]
    selectable_pheno = jnp.where(state.is_active, pheno, -jnp.inf)
    _, top_indices = jax.lax.top_k(selectable_pheno, k=config.n_select)

    # 2. Random mating: shuffle the selected parents, pair first half with
    #    second half. Fathers are sliced to exactly n_pairs so both parent
    #    vectors match in length when n_select is odd.
    shuffled_indices = jax.random.permutation(selection_key, top_indices)
    mother_indices_base = shuffled_indices[:n_pairs]
    father_indices_base = shuffled_indices[n_pairs : 2 * n_pairs]
    # Assign offspring to pairs round-robin so exactly population_size
    # offspring are produced. This matches the previous jnp.tile layout
    # whenever population_size is divisible by n_pairs, and stays correct
    # when it is not.
    pair_of_offspring = jnp.arange(config.population_size) % n_pairs
    mother_indices = mother_indices_base[pair_of_offspring]
    father_indices = father_indices_base[pair_of_offspring]

    # 3. Meiosis: one offspring per (mother, father) entry.
    offspring_geno, offspring_ibd = produce_offspring(
        key=meiosis_key,
        state=state,
        sp=sp,
        config=config,
        mother_indices=mother_indices,
        father_indices=father_indices,
    )

    # 4. Write the new cohort at write_pos. lax.dynamic_update_slice is needed
    #    for traced start indices and takes one start index per array rank.
    write_pos = state.write_pos
    new_ids = state.next_id + jnp.arange(config.population_size)
    geno = lax.dynamic_update_slice(state.geno, offspring_geno, (write_pos, 0, 0, 0))
    ibd = lax.dynamic_update_slice(state.ibd, offspring_ibd, (write_pos, 0, 0, 0))
    id_col = lax.dynamic_update_slice(state.id, new_ids, (write_pos,))
    mother = lax.dynamic_update_slice(
        state.mother, state.id[mother_indices], (write_pos,)
    )
    father = lax.dynamic_update_slice(
        state.father, state.id[father_indices], (write_pos,)
    )
    new_gen_arr = jnp.full(
        (config.population_size,), state.gen_idx + 1, dtype=jnp.int32
    )
    gen = lax.dynamic_update_slice(state.gen, new_gen_arr, (write_pos,))

    # Activate the new cohort...
    is_active = lax.dynamic_update_slice(
        state.is_active, jnp.full((config.population_size,), True), (write_pos,)
    )
    # ...and retire the cohort written `retention_generations` cohorts ago.
    if config.retention_generations > 0:
        deactivation_pos = (
            write_pos - config.retention_generations * config.population_size
        ) % config.max_pop_size
        is_active = lax.dynamic_update_slice(
            is_active,
            jnp.full((config.population_size,), False),
            (deactivation_pos,),
        )

    # 5. Recompute phenotypes / breeding values for the whole buffer, using
    #    the module-level `founder_pop` as a template and the module-level h2.
    # NOTE(review): `pheno`/`bv` are replaced for retained individuals too
    #    (presumably with fresh environmental noise) -- confirm that is intended.
    temp_pop = founder_pop.replace(geno=geno, is_active=is_active)
    pop_with_new_pheno = set_pheno_h2(key=pheno_key, pop=temp_pop, sp=sp, h2=h2)

    # 6. Assemble the updated state.
    return state.replace(
        key=key,
        geno=geno,
        ibd=ibd,
        pheno=pop_with_new_pheno.pheno,
        bv=pop_with_new_pheno.bv,
        is_active=is_active,
        id=id_col,
        mother=mother,
        father=father,
        gen=gen,
        write_pos=(write_pos + config.population_size) % config.max_pop_size,
        gen_idx=state.gen_idx + 1,
        next_id=state.next_id + config.population_size,
    )
# ==============================================================================
# --- 3. Running Replicates with vmap ---
# ==============================================================================
print("\n--- Initializing Replicates ---")
n_replicates = 3
key, *replicate_keys = jax.random.split(key, n_replicates + 1)
initial_states = [
    init_state_from_founders(
        key=rep_key, founder_pop=founder_pop_with_pheno, sp=sp, config=config
    )
    for rep_key in replicate_keys
]
# Stack matching leaves of every replicate's state into one batched pytree
# whose leading axis indexes the replicate.
batched_initial_state = jax.tree_util.tree_map(
    lambda *leaves: jnp.stack(leaves), *initial_states
)
print(f"Created a batch of {n_replicates} initial states.")
print(f"Shape of batched geno: {batched_initial_state.geno.shape}")
print(f"Shape of batched keys: {batched_initial_state.key.shape}")

# Map over the replicate axis of the state only; `sp` and `config` are shared.
vmapped_generation_step = vmap(generation_step, in_axes=(0, None, None))


def run_simulation(initial_carry, n_steps):
    """Apply the vmapped generation step ``n_steps`` times via ``lax.scan``.

    Uses the module-level ``sp`` and ``config``; no per-step outputs are
    collected, only the final batched state is returned.
    """

    def _advance(carry, _unused):
        return vmapped_generation_step(carry, sp, config), None

    final_carry, _ = lax.scan(_advance, initial_carry, None, length=n_steps)
    return final_carry


print(f"\n--- Running {n_replicates} Replicates for {config.n_generations} Generations ---")
final_states = run_simulation(batched_initial_state, config.n_generations)
print("Simulation complete.")
# ==============================================================================
# --- 4. Analyze Results ---
# ==============================================================================
print("\n--- Analysis of Final States ---")
for i in range(n_replicates):
    # Slice replicate `i` out of every leaf of the batched final state.
    final_state_rep = jax.tree_util.tree_map(lambda leaf: leaf[i], final_states)
    active_mask = final_state_rep.is_active
    final_gen_idx = final_state_rep.gen_idx
    # Summary statistics over active individuals only.
    final_mean_bv = jnp.mean(final_state_rep.bv[active_mask])
    final_var_a = jnp.var(final_state_rep.bv[active_mask])
    print(
        "\n".join(
            [
                f"\n--- Replicate {i + 1} ---",
                f" Final generation index: {final_gen_idx}",
                f" Final mean breeding value: {final_mean_bv:.4f}",
                f" Final additive variance (VarA): {final_var_a:.4f}",
            ]
        )
    )
--- Defining Trait ---
Trait added with 100 QTLs.
--- Setting Initial Phenotypes with h2 ---
--- Initializing Replicates ---
Created a batch of 3 initial states.
Shape of batched geno: (3, 1000, 10, 2, 100)
Shape of batched keys: (3, 2)
--- Running 3 Replicates for 10 Generations ---
Simulation complete.
--- Analysis of Final States ---
--- Replicate 1 ---
Final generation index: 10
Final mean breeding value: 6.4441
Final additive variance (VarA): 0.5065
--- Replicate 2 ---
Final generation index: 10
Final mean breeding value: 6.6405
Final additive variance (VarA): 0.5850
--- Replicate 3 ---
Final generation index: 10
Final mean breeding value: 6.3665
Final additive variance (VarA): 0.4618