Commit 90b46114 authored by Moyuan CHEN's avatar Moyuan CHEN
Browse files

Merge branch 'revert-51b52b25' into 'team'

Revert "Update utils.py"

See merge request !5
parents 51b52b25 9379118a
......@@ -153,10 +153,10 @@ def get_species_from_distribution(N, species_distribution, key=None):
returns: [0,0,1,1,2,2,3,3,4,4]
"""
species = jnp.zeros(shape=(N,))
species = np.zeros(shape=(N,))
species_distribution = species_distribution / jnp.sum(species_distribution)
species_dist_N = jnp.array(species_distribution * N).astype(int)
species_distribution = species_distribution / np.sum(species_distribution)
species_dist_N = np.array(species_distribution * N).astype(int)
def fill_in_number_body_fn(ind, sv):
s, v = sv
......@@ -173,7 +173,7 @@ def get_species_from_distribution(N, species_distribution, key=None):
p = species_distribution * N - species_dist_N
if key is None:
additions = jnp.zeros(N, dtype=jnp.int32).at[N-p.shape[0]:].set(jnp.argsort(p))
additions = np.zeros(N, dtype=np.int32).at[N-p.shape[0]:].set(np.argsort(p))
else:
additions = jnp.zeros(N, dtype=jnp.int32).at[N-p.shape[0]:].set(random.choice(key, species_distribution.shape[0], (N-particle_index,), False, p=p))
return jnp.sort(jnp.where(jnp.arange(N)>=particle_index, additions, species))
additions = np.zeros(N, dtype=np.int32).at[N-p.shape[0]:].set(random.choice(key, species_distribution.shape[0], (N-particle_index,), False, p=p))
return np.sort(np.where(np.arange(N)>=particle_index, additions, species))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment