Commit 9d8183b7 authored by Carl Goodrich's avatar Carl Goodrich
Browse files

add MoYuan's get_species_from_distribution function

parent 80671e9e
......@@ -133,3 +133,54 @@ def vector2symmat_diag0(v):
"""
n = int(((1 + 8 * v.shape[0]) ** 0.5 + 1) / 2)
return _vector2symmat_diag0(v, jnp.zeros((n, n), dtype=v.dtype))
def get_species_from_distribution(N, species_distribution, key=None):
""" Convert an array of probability distributions of M species into an
array with length N and matching distribution of the species.
Args:
N: total number of particles
species_dist: array of shape (M,) indicating the desired distribution of
species, or an int M indicating an even distribution of M species
key [Optional]: RNG key to draw random numbers to fill out the remainder
of the species array when N is not a even multiple of M
Return:
1-D Array species that has length N and specified distribution from
species_dist
Example:
N = 10, species_dist = jnp.array([0.2,0.2,0.2,0.2,0.2])
returns: [0,0,1,1,2,2,3,3,4,4]
"""
if isinstance(species_distribution, int):
_species_distribution = jnp.ones((species_distribution,), dtype=jnp.float32) / species_distribution
else:
_species_distribution = species_distribution
species = jnp.zeros(shape=(N,), dtype=jnp.int32)
_species_distribution = _species_distribution / jnp.sum(_species_distribution)
species_dist_N = jnp.array(_species_distribution * N).astype(int)
particle_index = 0
species_index = 0
for n_species_i in species_dist_N:
#species = species.at[particle_index:particle_index + n_species_i].set(jnp.full((n_species_i,), species_index, dtype=jnp.int32))
species = species.at[particle_index:particle_index + n_species_i].set(species_index * jnp.ones((n_species_i,), dtype=jnp.int32))
species_index += 1
particle_index += n_species_i
p = _species_distribution * N - species_dist_N
additions = jnp.zeros(N, dtype=jnp.int32)
if key is None:
additions = additions.at[N-p.shape[0]:].set(jnp.argsort(p))
else:
additions = additions.at[N-p.shape[0]:].set(random.choice(key, _species_distribution.shape[0], (p.shape[0],), True, p=p))
return jnp.sort(jnp.where(jnp.arange(N)>=particle_index, additions, species))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment