### Merge branch 'mchen-team-patch-27828' into 'team'

```Update utils.py

See merge request !4```
parents d97aa7df d55cebf2
 ... ... @@ -133,3 +133,49 @@ def vector2symmat_diag0(v): """ n = int(((1 + 8 * v.shape) ** 0.5 + 1) / 2) return _vector2symmat_diag0(v, jnp.zeros((n, n), dtype=v.dtype)) def get_species_from_distribution(N, species_distribution, key=None): """ Convert an array of probability distributions of M species into an array with length N and matching distribution of the species. Args: N: total number of particles species_dist: the desired distribution of species key [Optional]: RNG key to draw random numbers to fill out the remainder of the species array when N is not a even multiple of M Return: 1-D Array species that has length N and specified distribution from species_dist Example: N = 10, species_dist = jnp.array([0.2,0.2,0.2,0.2,0.2]) returns: [0,0,1,1,2,2,3,3,4,4] """ species = np.zeros(shape=(N,)) species_distribution = species_distribution / np.sum(species_distribution) species_dist_N = np.array(species_distribution * N).astype(int) def fill_in_number_body_fn(ind, sv): s, v = sv s = s.at[ind].set(v) return (s,v) particle_index = 0 species_index = 0 for n_species_i in species_dist_N: species, _ = fori_loop(particle_index, particle_index + n_species_i, fill_in_number_body_fn, (species, species_index)) species_index += 1 particle_index += n_species_i p = species_distribution * N - species_dist_N if key is None: additions = np.zeros(N, dtype=np.int32).at[N-p.shape:].set(np.argsort(p)) else: additions = np.zeros(N, dtype=np.int32).at[N-p.shape:].set(random.choice(key, species_distribution.shape, (N-particle_index,), False, p=p)) return np.sort(np.where(np.arange(N)>=particle_index, additions, species)) get_species_from_distribution = jit(get_species_from_distribution, static_argnums=(0,))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment