Commit 72491d18 authored by Maximilian LECHNER's avatar Maximilian LECHNER
Browse files

Removed Brownian notebook. To be replaced with joint minimization/brownian notebook.

parent 4e3f7acb
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import numpy as onp
import jax.numpy as jnp
from jax.config import config
config.update('jax_enable_x64', True)
from jax import random
from jax import jit, lax, grad, vmap, custom_vjp, hessian, jacfwd
from jax_md import space, smap, energy, minimize, quantity, simulate
f32 = jnp.float32
f64 = jnp.float64
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})
def format_plot(x, y):
plt.grid(True)
plt.xlabel(x, fontsize=20)
plt.ylabel(y, fontsize=20)
def finalize_plot(shape=(1, 0.7)):
plt.gcf().set_size_inches(
shape[0] * 1.5 * plt.gcf().get_size_inches()[1],
shape[1] * 1.5 * plt.gcf().get_size_inches()[1])
```
%% Cell type:code id: tags:
``` python
#This version is fully differentiable and internally jitted
def run_minimization_scan(energy_fn, R_init, shift, num_steps=5000, **kwargs):
init,apply=minimize.fire_descent(jit(energy_fn), shift, **kwargs)
apply = jit(apply)
@jit
def scan_fn(state, i):
return apply(state), 0.
state = init(R_init)
state, _ = lax.scan(scan_fn,state,jnp.arange(num_steps))
return state.position, jnp.amax(jnp.abs(-grad(energy_fn)(state.position)))
#This version is internally jitted but only forward mode differentiable only
#The benefit is that it terminates when properly minimized
#Always use this version unless you need backward AD
def run_minimization_while(energy_fn, R_init, shift, max_grad_thresh = 1e-12, max_num_steps=1000000, **kwargs):
init,apply=minimize.fire_descent(jit(energy_fn), shift, **kwargs)
apply = jit(apply)
@jit
def get_maxgrad(state):
return jnp.amax(jnp.abs(state.force))
@jit
def cond_fn(val):
state, i = val
return jnp.logical_and(get_maxgrad(state) > max_grad_thresh, i<max_num_steps)
@jit
def body_fn(val):
state, i = val
return apply(state), i+1
state = init(R_init)
state, num_iterations = lax.while_loop(cond_fn, body_fn, (state, 0))
return state.position, get_maxgrad(state), num_iterations
#Run a brownian dynamics simulation and save periodic snapshots
def run_brownian(energy_fn, R_init, shift, key, nun_total_steps, record_every, dt, **kwargs):
#define the simulation
init, apply = simulate.brownian(energy_fn, shift, dt, **kwargs)
apply = jit(apply)
@jit
def apply_single_step(state, t):
return apply(state, t=t), 0
@jit
def apply_many_steps(state, t_list):
state, _ = lax.scan(apply_single_step, state, t_list)
return state, state.position
#initialize the system
key, split = random.split(key)
initial_state = init(split, R_init)
#run the simulation
final_state, trajectory = lax.scan(apply_many_steps, initial_state, jnp.arange(nun_total_steps).reshape(nun_total_steps//record_every,record_every))
#return the trajectory
return trajectory
```
%% Cell type:code id: tags:
``` python
def box_size_at_packing_fraction_2d(phi,diameters):
Vspheres = jnp.sum(vmap(lambda d: jnp.pi*(d/2)**2)(diameters))
return (Vspheres/phi)**(0.5)
N = 50
dimension = 2
diameters = jnp.linspace(1.0,1.4,N)
box_size = box_size_at_packing_fraction_2d(0.9, diameters)
displacement, shift = space.periodic(box_size)
energy_fn = energy.soft_sphere_pair(displacement, sigma=diameters)
key = random.PRNGKey(4)
key, split = random.split(key)
R_init = random.uniform(split, (N,dimension), minval=0.0, maxval=box_size, dtype=f64)
```
%% Cell type:code id: tags:
``` python
R_final, maxgrad = run_minimization_scan(energy_fn, R_init, shift, 100)
print('minimization ran for {} steps'.format(100))
print(maxgrad)
```
%% Cell type:code id: tags:
``` python
R_final, maxgrad, i_last = run_minimization_while(energy_fn, R_init, shift)
print('minimization ran for {} steps'.format(i_last))
print(maxgrad)
```
%% Cell type:code id: tags:
``` python
key, split = random.split(key)
trajectory = run_brownian(energy_fn, R_final, shift, split, 10000, 100, dt=0.0001, T_schedule=0.01, gamma=0.1)
```
%% Cell type:code id: tags:
``` python
```
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment