Commit d97aa7df authored by Carl Goodrich's avatar Carl Goodrich
Browse files

Merge branch 'add_tests' into 'team'

Add tests

See merge request !3
parents c8f2280f b40f0a81
...@@ -67,13 +67,13 @@ def run_brownian_neighbor_list( ...@@ -67,13 +67,13 @@ def run_brownian_neighbor_list(
measure_fn=lambda R, nbrs: R, measure_fn=lambda R, nbrs: R,
**static_kwargs **static_kwargs
): ):
nbrs = neighbor_fn(R_init) nbrs = neighbor_fn.allocate(R_init)
init, apply = simulate.brownian(energy_fn, shift, dt=dt, kT=kT, gamma=gamma) init, apply = simulate.brownian(energy_fn, shift, dt=dt, kT=kT, gamma=gamma)
def body_fn(state_nbrs, t): def body_fn(state_nbrs, t):
state, nbrs = state_nbrs state, nbrs = state_nbrs
nbrs = neighbor_fn(state.position, nbrs) nbrs = neighbor_fn.update(state.position, nbrs)
state = apply(state, neighbor=nbrs, **static_kwargs) state = apply(state, neighbor=nbrs, **static_kwargs)
return (state, nbrs), 0 return (state, nbrs), 0
...@@ -93,7 +93,7 @@ def run_brownian_neighbor_list( ...@@ -93,7 +93,7 @@ def run_brownian_neighbor_list(
# the simulation. # the simulation.
if nbrs.did_buffer_overflow: if nbrs.did_buffer_overflow:
print("Buffer overflow.") print("Buffer overflow.")
nbrs = neighbor_fn(state.position) nbrs = neighbor_fn.allocate(state.position)
else: else:
state = new_state state = new_state
step += step_inc step += step_inc
......
...@@ -65,7 +65,7 @@ def run_minimization_while_neighbor_list( ...@@ -65,7 +65,7 @@ def run_minimization_while_neighbor_list(
verbose=False, verbose=False,
**kwargs **kwargs
): ):
nbrs = neighbor_fn(R_init) nbrs = neighbor_fn.allocate(R_init)
init, apply = minimize.fire_descent(jit(energy_fn), shift, **kwargs) init, apply = minimize.fire_descent(jit(energy_fn), shift, **kwargs)
apply = jit(apply) apply = jit(apply)
...@@ -77,7 +77,7 @@ def run_minimization_while_neighbor_list( ...@@ -77,7 +77,7 @@ def run_minimization_while_neighbor_list(
@jit @jit
def body_fn(state_nbrs, t): def body_fn(state_nbrs, t):
state, nbrs = state_nbrs state, nbrs = state_nbrs
nbrs = neighbor_fn(state.position, nbrs) nbrs = neighbor_fn.update(state.position, nbrs)
state = apply(state, neighbor=nbrs) state = apply(state, neighbor=nbrs)
return (state, nbrs), 0 return (state, nbrs), 0
...@@ -93,7 +93,7 @@ def run_minimization_while_neighbor_list( ...@@ -93,7 +93,7 @@ def run_minimization_while_neighbor_list(
# the simulation. # the simulation.
if nbrs.did_buffer_overflow: if nbrs.did_buffer_overflow:
print("Buffer overflow.") print("Buffer overflow.")
nbrs = neighbor_fn(state.position) nbrs = neighbor_fn.allocate(state.position)
else: else:
state = new_state state = new_state
step += step_inc step += step_inc
......
from absl.testing import absltest
from absl.testing import parameterized
from jax.config import config as jax_config
from jax import test_util as jtu
from jax import random
import jax.numpy as jnp
from jax_md import energy, space, quantity
from jax_md.util import *
from common_utils import brownian
from jax.config import config as jax_config
jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS
PARTICLE_COUNT = 256
SPATIAL_DIMENSION = [2, 3]
STOCHASTIC_SAMPLES = 1
if FLAGS.jax_enable_x64:
DTYPE = [f32, f64]
else:
DTYPE = [f32]
def setup_system(N, dimension, key, dtype):
diameters = 1.0
box_size = quantity.box_size_at_number_density(N, 0.4, dimension)
displacement, shift = space.periodic(box_size)
energy_fn = energy.soft_sphere_pair(displacement, sigma=diameters)
R_init = random.uniform(key, (N,dimension), minval=0.0, maxval=box_size, dtype=dtype)
return displacement, shift, energy_fn, R_init, box_size
def setup_system_nl(N, dimension, key, dtype):
diameters = 1.0
box_size = quantity.box_size_at_number_density(N, 0.4, dimension)
displacement, shift = space.periodic(box_size)
#energy_fn = energy.soft_sphere_pair(displacement, sigma=diameters)
neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
displacement, box_size, sigma=diameters, dr_threshold = 0.2,
capacity_multiplier = 1.5)
R_init = random.uniform(key, (N,dimension), minval=0.0, maxval=box_size, dtype=dtype)
return displacement, shift, energy_fn, neighbor_fn, R_init, box_size
class BrownianTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(
dim, dtype.__name__),
'spatial_dimension': dim,
'dtype': dtype
} for dim in SPATIAL_DIMENSION
for dtype in DTYPE))
def test_brownian(self, spatial_dimension, dtype):
key = random.PRNGKey(0)
for _ in range(STOCHASTIC_SAMPLES):
key, pos_key, brownian_key = random.split(key, 3)
_, shift, energy_fn, R_init, _ = setup_system(
PARTICLE_COUNT, spatial_dimension, pos_key, dtype)
nsteps, record_every = 1000, 10
final_state, trajectory = brownian.run_brownian(
energy_fn, R_init, shift, brownian_key,
num_total_steps=nsteps, record_every=record_every,
dt=0.0001, kT=0.01, gamma=0.1)
assert trajectory.shape == (nsteps//record_every, PARTICLE_COUNT, spatial_dimension)
self.assertAllClose(final_state.position, trajectory[-1])
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(
dim, dtype.__name__),
'spatial_dimension': dim,
'dtype': dtype
} for dim in SPATIAL_DIMENSION
for dtype in DTYPE))
def test_measurement(self, spatial_dimension, dtype):
key = random.PRNGKey(0)
for _ in range(STOCHASTIC_SAMPLES):
key, pos_key, brownian_key = random.split(key, 3)
displacement, shift, energy_fn, R_init, box_size = setup_system(
PARTICLE_COUNT, spatial_dimension, pos_key, dtype)
rs = jnp.linspace(0,box_size/2.0, 101)[1:]
g_fn = quantity.pair_correlation(displacement, rs, 0.1)
def measurement(R):
return jnp.mean(g_fn(R),axis=0)
nsteps, record_every = 1000, 10
final_state, gofr_all = brownian.run_brownian(
energy_fn, R_init, shift, brownian_key,
num_total_steps=nsteps, record_every=record_every,
dt=0.0001, kT=0.01, gamma=0.1,
measure_fn = measurement)
assert gofr_all.shape == (nsteps//record_every, 100)
assert final_state.position.shape == R_init.shape
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(
dim, dtype.__name__),
'spatial_dimension': dim,
'dtype': dtype
} for dim in SPATIAL_DIMENSION
for dtype in DTYPE))
def test_brownian_nl(self, spatial_dimension, dtype):
key = random.PRNGKey(0)
for _ in range(STOCHASTIC_SAMPLES):
key, pos_key, brownian_key = random.split(key, 3)
_, shift, energy_fn, neighbor_fn, R_init, _ = setup_system_nl(
PARTICLE_COUNT, spatial_dimension, pos_key, dtype)
nsteps, record_every = 1000, 10
final_state, trajectory, nbrs = brownian.run_brownian_neighbor_list(
energy_fn, neighbor_fn, R_init, shift, brownian_key,
num_steps=nsteps, step_inc=record_every,
dt=0.0001, kT=0.001, gamma=0.1)
assert trajectory.shape == (nsteps//record_every, PARTICLE_COUNT, spatial_dimension)
self.assertAllClose(final_state.position, trajectory[-1])
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(
dim, dtype.__name__),
'spatial_dimension': dim,
'dtype': dtype
} for dim in SPATIAL_DIMENSION
for dtype in DTYPE))
def test_measurement_nl(self, spatial_dimension, dtype):
key = random.PRNGKey(0)
for _ in range(STOCHASTIC_SAMPLES):
key, pos_key, brownian_key = random.split(key, 3)
displacement, shift, energy_fn, neighbor_fn, R_init, box_size = setup_system_nl(
PARTICLE_COUNT, spatial_dimension, pos_key, dtype)
rs = jnp.linspace(0,4.0, 101)[1:]
gnbr_fn, g_fn = quantity.pair_correlation_neighbor_list(displacement, box_size, rs, 0.1)
gnbrs = gnbr_fn.allocate(R_init)
def measurement(R,nbrs,gnbrs):
gnbrs = gnbr_fn.update(R, gnbrs)
return jnp.mean(g_fn(R,gnbrs),axis=0)
nsteps, record_every = 1000, 10
final_state, gofr_all, nbrs = brownian.run_brownian_neighbor_list(
energy_fn, neighbor_fn, R_init, shift, brownian_key,
num_steps=nsteps, step_inc=record_every,
dt=0.0001, kT=0.001, gamma=0.1,
measure_fn=measurement, gnbrs=gnbrs)
assert gofr_all.shape == (nsteps//record_every, 100)
assert final_state.position.shape == R_init.shape
if __name__ == '__main__':
absltest.main()
from absl.testing import absltest
from jax import test_util as jtu
import jax.numpy as jnp
from common_utils import utils
from jax.config import config as jax_config
jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS
class utilsTest(jtu.JaxTestCase):
def test_vector2symmat(self):
v = jnp.array([0,1,2,3,4,5],dtype=jnp.float32)
m = utils.vector2symmat(v)
m_expected = jnp.array([[0,1,2],
[1,3,4],
[2,4,5]],dtype=v.dtype)
self.assertAllClose(m, m_expected)
def test_vector2symmat_diag0(self):
v = jnp.array([0,1,2,3,4,5],dtype=jnp.float32)
m = utils.vector2symmat_diag0(v)
m_expected = jnp.array([[0,0,1,2],
[0,0,3,4],
[1,3,0,5],
[2,4,5,0]],dtype=v.dtype)
self.assertAllClose(m, m_expected)
if __name__ == '__main__':
absltest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment