Commit d55cebf2 authored by Moyuan CHEN's avatar Moyuan CHEN
Browse files

Update common_utils/utils.py

parent 9518e86f
......@@ -135,8 +135,8 @@ def vector2symmat_diag0(v):
return _vector2symmat_diag0(v, jnp.zeros((n, n), dtype=v.dtype))
def get_species_from_distribution(N, species_dist):
""" Convert an array of probability distributions of M species into an
def get_species_from_distribution(N, species_distribution, key=None):
""" Convert an array of probability distributions of M species into an
array with length N and matching distribution of the species.
Args:
......@@ -153,29 +153,29 @@ def get_species_from_distribution(N, species_dist):
returns: [0,0,1,1,2,2,3,3,4,4]
"""
N_array = np.zeros(shape=(N,))
@jit
def jittable_sfd(species = N_array, species_dist= species_dist):
species_dist = species_dist / np.sum(species_dist)
species_dist_N = np.array(species_dist * N).astype(int)
def fill_in_number_body_fn(ind, sv):
s, v = sv
s = s.at[ind].set(v)
return (s,v)
particle_index = 0
species_index = 0
for n_species_i in species_dist_N:
species, _ = fori_loop(particle_index, particle_index + n_species_i, fill_in_number_body_fn, (species, species_index))
species_index += 1
particle_index += n_species_i
most_popular_specie = np.argsort(-species_dist_N)[0]
species = fori_loop(particle_index, N, fill_in_number_body_fn, (species, most_popular_specie))
return species
species = jittable_sfd()
return species
species = np.zeros(shape=(N,))
species_distribution = species_distribution / np.sum(species_distribution)
species_dist_N = np.array(species_distribution * N).astype(int)
def fill_in_number_body_fn(ind, sv):
s, v = sv
s = s.at[ind].set(v)
return (s,v)
particle_index = 0
species_index = 0
for n_species_i in species_dist_N:
species, _ = fori_loop(particle_index, particle_index + n_species_i, fill_in_number_body_fn, (species, species_index))
species_index += 1
particle_index += n_species_i
p = species_distribution * N - species_dist_N
if key is None:
additions = np.zeros(N, dtype=np.int32).at[N-p.shape[0]:].set(np.argsort(p))
else:
additions = np.zeros(N, dtype=np.int32).at[N-p.shape[0]:].set(random.choice(key, species_distribution.shape[0], (N-particle_index,), False, p=p))
return np.sort(np.where(np.arange(N)>=particle_index, additions, species))
get_species_from_distribution = jit(get_species_from_distribution, static_argnums=(0,))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment