Commit c6528be8 authored by Maximilian LECHNER's avatar Maximilian LECHNER
Browse files

Removed run_brownian_neighbor_list_unsafe since run_brownian_neighbor_list is...

Removed run_brownian_neighbor_list_unsafe since run_brownian_neighbor_list is now differentiable due to a change in jax-md.
parent 57538da6
......@@ -51,56 +51,7 @@ def run_brownian(
return final_state, data
# DIFFERENTIABLE but dangerous! ALWAYS CHECK IF THE BUFFER OVERFLOWED
def run_brownian_neighbor_list_unsafe(
energy_fn,
neighbor_fn,
nbrs,
R_init,
shift,
key,
num_total_steps,
record_every=1000,
dt=0.0001,
kT=1.0,
gamma=1.0,
measure_fn=lambda R, nbrs: R,
**static_kwargs
):
init, apply = simulate.brownian(energy_fn, shift, dt=dt, kT=kT, gamma=gamma)
@jit
def apply_single_step(state_nbrs, t):
state, nbrs = state_nbrs
nbrs = neighbor_fn(state.position, nbrs)
state = apply(state, neighbor=nbrs, **static_kwargs)
return (state, nbrs), 0
@jit
def apply_many_steps(state_nbrs, t_list):
state_nbrs, _ = lax.scan(apply_single_step, state_nbrs, t_list)
return (
state_nbrs,
measure_fn(state_nbrs[0].position, nbrs=state_nbrs[1], **static_kwargs),
)
key, split = random.split(key)
initial_state = init(split, R_init)
# run the simulation
final_state_nbrs, data = lax.scan(
apply_many_steps,
(initial_state, nbrs),
jnp.arange(num_total_steps).reshape(
num_total_steps // record_every, record_every
),
)
final_state, final_nbrs = final_state_nbrs
return final_state, data, final_nbrs
# NON DIFFERENTIABLE but safe
# DIFFERENTIABLE and uses neighbor_lists
def run_brownian_neighbor_list(
energy_fn,
neighbor_fn,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment