Source code for mxtaltools.crystal_search.crystal_opt_utils

import copy
import gc
from argparse import Namespace
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch import optim
from torch_scatter import scatter
from tqdm import tqdm

from mxtaltools.analysis.crystal_rdf import compute_rdf_distance
from mxtaltools.common.geometry_utils import enforce_crystal_system
from mxtaltools.common.utils import is_cuda_oom
from mxtaltools.dataset_utils.utils import collate_data_list
from mxtaltools.models.utils import enforce_1d_bound, softmax_and_score


def dict2namespace(data_dict: dict):
    """
    Recursively converts a dictionary and its internal dictionaries into an
    argparse.Namespace

    The input dictionary is not modified: nested dictionaries are converted
    into new Namespace objects instead of being overwritten in place.

    Parameters
    ----------
    data_dict : dict
        The input dictionary

    Return
    ------
    data_namespace : argparse.Namespace
        The output namespace
    """
    converted = {
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in data_dict.items()
    }
    return Namespace(**converted)
def get_annealing_factor(start_value, stop_value, total_time, step_iters):
    """
    Per-update multiplicative factor that carries `start_value` to
    `stop_value` geometrically.

    The number of updates is `total_time / step_iters`; applying the returned
    factor that many times to `start_value` yields `stop_value`.
    """
    assert stop_value > 0, "Setting final value as zero breaks this module"
    overall_ratio = stop_value / start_value
    num_updates = total_time / step_iters
    return overall_ratio ** (1 / num_updates)
class CrystalParams(nn.Module):
    """
    Holds one trainable parameter vector per sample, with optional frozen
    columns.

    Columns listed in `fixed_dims` are stored in the `fixed_vals` buffer and
    are never handed to the optimizer; all remaining columns are stored as one
    `nn.Parameter` per row of `init_sample`.
    """

    def __init__(self, init_sample, fixed_dims=None):
        """
        init_sample: [n_samples, param_dim] tensor of starting values
        fixed_dims: optional list of column indices to freeze
        """
        super().__init__()
        # fixed_dims: list of column indices to freeze
        self.fixed_dims = fixed_dims if fixed_dims is not None else []
        self.free_dims = [i for i in range(init_sample.shape[1]) if i not in self.fixed_dims]

        # Only free dims are Parameters
        self.params = nn.ParameterList(
            [nn.Parameter(row[self.free_dims].clone().detach()) for row in init_sample]
        )
        # Fixed dims stored as buffer (not optimized, but accessible).
        # Index with self.fixed_dims (never the raw, possibly-None argument) so
        # that the buffer is [n_samples, n_fixed], i.e. [n_samples, 0] when
        # nothing is frozen.
        self.register_buffer('fixed_vals', init_sample[:, self.fixed_dims].clone().detach())

    def stacked(self):
        """
        Reassemble the full [n_samples, param_dim] tensor from the free
        parameters and the frozen buffer.
        """
        free = torch.stack(list(self.params), dim=0)  # [n, n_free]
        full = torch.zeros(free.shape[0], free.shape[1] + len(self.fixed_dims),
                           device=free.device, dtype=free.dtype)
        full[:, self.free_dims] = free
        if len(self.fixed_dims) > 0:
            full[:, self.fixed_dims] = self.fixed_vals
        return full

    def update_params(self, new_values, mask=None):
        """
        new_values: [batch_size, param_dim] tensor
            (a [batch_size, n_free] tensor holding only the free columns is
            also accepted)
        mask: optional boolean mask [batch_size] to only update accepted samples
        """
        # The stored parameters only hold the free columns; drop frozen columns
        # from full-width inputs so the in-place copy shapes agree.
        if new_values.shape[-1] != len(self.free_dims):
            new_values = new_values[:, self.free_dims]

        with torch.no_grad():
            if mask is None:
                for i, val in enumerate(new_values):
                    self.params[i].copy_(val)
            else:
                # Only update the specific batch indices that were accepted
                for idx in torch.where(mask)[0].tolist():
                    self.params[idx].copy_(new_values[idx])
def gradient_descent_optimization(  # todo consolidate kwargs somewhere
        init_sample: torch.Tensor,
        init_crystal_batch,
        optimizer_func: Union[torch.optim.Optimizer, str],
        init_lr: Optional[float] = 1e-1,
        max_num_steps: Optional[int] = 500,
        show_tqdm: Optional[bool] = False,
        convergence_eps: float = 1e-3,
        grad_norm_clip: Optional[float] = 0.1,
        optim_target: Optional[str] = 'lj',
        score_model: Optional[torch.nn.Module] = None,
        target_packing_coeff: Optional[torch.Tensor] = None,
        do_box_restriction: Optional[bool] = False,
        cutoff: Optional[float] = None,
        compression_factor: Optional[float] = 0,
        enforce_reduced: Optional[bool] = False,
        supercell_size: Optional[int] = 10,
        anneal_lr: Optional[bool] = False,
        predictor: Optional = None,
        mc_kT: float = 0.5,
        log_noise_range=[-2, -1],
        target_rdf: torch.tensor = None,
        target_latent: torch.tensor = None,
        elementwise: bool = True,
        atomwise: bool = False,
        repulsion: float = 1.0,
        umbrella: Optional[bool] = False,  # do umbrella sampling in latent space
        umbrella_sigma: Optional[float] = None,  # bandwidth term for umbrella sampling
        umbrella_epsilon: Optional[float] = None,  # repulsion term for umbrella sampling
        umbrella_record: Optional[list] = None,
        rdf_warmup: Optional[torch.tensor] = 500,
):
    """
    do a local optimization via gradient descent on some score function

    Each row of `init_sample` becomes a trainable parameter vector
    (`CrystalParams`). On every step a fresh clone of `init_crystal_batch`
    receives the current parameters, is analyzed, and the resulting loss is
    back-propagated. The loop stops after `max_num_steps - 1` steps or once
    `check_convergence` flags every trajectory.

    Parameters
    ----------
    init_sample : torch.Tensor
        Starting parameters, one row per crystal in `init_crystal_batch`.
    init_crystal_batch
        Crystal batch that is cloned at every step (project type).
    optimizer_func : str or optimizer class
        Forwarded to `init_opt`. The string 'mcmc' selects the Monte Carlo
        path, which is currently disabled by an assertion.
    init_lr, anneal_lr
        Initial learning rate; when `anneal_lr` is set the rate is multiplied
        every step by a factor that reaches 1% after 75% of `max_num_steps`.
    optim_target : str
        Name of the loss selected in `compute_loss`.
    cutoff
        Analysis cutoff, defaults to 10 when None.
    enforce_reduced
        Adds the 'reduction_en' compute and penalty, and drops final samples
        whose cell reduction penalty is not below 1e-3.
    repulsion
        Only used for the final re-analysis of the best samples; the
        in-loop analysis uses a fixed value of 1.
    Remaining keyword arguments are packed into the loss / auxiliary-loss
    configs or forwarded to `crystal_batch.analyze`.

    Returns
    -------
    samples_list : list
        Best sample of every trajectory (lowest recorded loss), re-analyzed.
    records : dict
        Stacked per-step records, plus 'params' with the parameter history.

    Raises
    ------
    RuntimeError, ValueError
        Re-raised from the optimization loop. If the error is a CUDA OOM and
        at least one step completed, the best parameters so far are first
        written to 'opt_intermediates.pt'.
    """
    # todo implement wrapping over periodic latent DoF
    if cutoff is None:  # lennard jones need 10 angstroms to nicely converge
        cutoff = 10

    energy_computes = ['lj', 'elj']
    min_num_steps = 50
    num_samples = init_crystal_batch.num_graphs
    target_map = {
        'silu': 'silu',
        'qlj': 'qlj',
        'ellipsoid': 'ellipsoid',
        'reduce': 'reduction_en',
        'uma': 'uma',
        'mace': 'mace',
        'rdf_dist': 'rdf',
        'elj': 'elj',
        'lj': 'lj'
    }
    # Targets handled by compute_loss that do not read any extra entry of the
    # analysis outputs, so they need no additional compute key.
    no_compute_targets = ('inter_overlaps', 'classification_score', 'rdf_score', 'latent_dist')
    target_key = optim_target.lower()
    if target_key in target_map:
        energy_computes.append(target_map[target_key])
    elif target_key not in no_compute_targets:
        raise KeyError(target_key)
    energy_computes = list(set(energy_computes))

    if enforce_reduced:  # use an energy based penalty to enforce sampling in the reduced subspace
        energy_computes.append('reduction_en')
        std_margin = 0.01
    else:
        std_margin = 0

    if target_rdf is not None:
        target_rdf = target_rdf.to(init_sample.device)
    if target_latent is not None:
        target_latent = target_latent.to(init_sample.device)

    if False:  # target_rdf is not None:  # assumes we already have the box
        fixed_dims = [0, 1, 2, 3, 4, 5]
    else:
        fixed_dims = None

    param_module = CrystalParams(init_sample, fixed_dims=fixed_dims)

    loss_config = dict2namespace({
        'cutoff': cutoff,
        'target_rdf': target_rdf,
        'score_model': score_model,
        'optim_target': optim_target,
        'target_latent': target_latent,
        'rdf_warmup': rdf_warmup,
    })
    aux_config = dict2namespace({
        'compression_factor': compression_factor,
        'do_box_restriction': do_box_restriction,
        'target_packing_coeff': target_packing_coeff,
        'enforce_reduced': enforce_reduced,
        'umbrella': umbrella,
        'umbrella_sigma': umbrella_sigma,
        'umbrella_epsilon': umbrella_epsilon,
        'umbrella_record': [] if umbrella_record is None else umbrella_record,
    })

    optimizer = init_opt(init_lr, optimizer_func, param_module)
    # optimizer_func may be an optimizer class rather than a string, in which
    # case it has no .lower() and can never select the Monte Carlo path.
    monte_carlo = isinstance(optimizer_func, str) and optimizer_func.lower() == 'mcmc'
    if monte_carlo:
        # We need a starting energy for the very first comparison
        assert False, "Rewrite with new args"
        with torch.no_grad():
            T, current_energy, current_state = init_mc_state(compression_factor, cutoff, do_box_restriction,
                                                             energy_computes,
                                                             enforce_reduced, init_crystal_batch, mc_kT, optim_target,
                                                             param_module, score_model, std_margin, supercell_size,
                                                             target_packing_coeff, predictor, target_rdf,
                                                             target_latent)

    if anneal_lr:
        lr_factor = get_annealing_factor(1, 0.01, int(max_num_steps * 0.75), 1)
    else:
        lr_factor = 1
    scheduler1 = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: lr_factor)

    params_record = torch.zeros((max_num_steps, init_crystal_batch.num_graphs, init_sample.shape[-1]),
                                dtype=torch.float32)

    if optim_target.lower() == 'ellipsoid':
        init_crystal_batch.load_ellipsoid_model()
        ellipsoid_model = copy.deepcopy(init_crystal_batch.ellipsoid_model)
        ellipsoid_model = ellipsoid_model.to(init_crystal_batch.device)
        ellipsoid_model.eval()
        # unify initialization of this vs other score models, uma models, etc.

    records = None
    converged = torch.zeros(init_crystal_batch.num_graphs, dtype=torch.bool)
    # torch.autograd.set_detect_anomaly(True)  # for debugging
    try:
        with torch.set_grad_enabled(not monte_carlo):
            with tqdm(total=max_num_steps, disable=not show_tqdm) as pbar:
                s_ind = 0
                while (s_ind < (max_num_steps - 1)) and not converged.all():
                    optimizer.zero_grad(set_to_none=True)
                    crystal_batch = init_crystal_batch.clone().detach()
                    if monte_carlo:
                        crystal_batch.set_cell_parameters(param_module.stacked(), skip_box_analysis=False)
                        noised_latents = get_mc_noised_samples(crystal_batch, log_noise_range)
                        crystal_batch.latent_to_cell_params(noised_latents,
                                                            skip_box_analysis=True,
                                                            skip_enforce_crystal_system=True)
                    else:
                        crystal_batch.set_cell_parameters(param_module.stacked(), skip_box_analysis=True)

                    crystal_batch.clean_cell_parameters(
                        mode='hard',
                        canonicalize_orientations=True,
                    )  # box analysis included in here
                    outputs, cluster_batch = crystal_batch.analyze(
                        computes=energy_computes,
                        cutoff=cutoff,
                        rdf_cutoff=cutoff,
                        supercell_size=supercell_size,
                        return_cluster=True,
                        repulsion=1,
                        surface_padding=0,
                        std_orientation=True,
                        margin=std_margin,
                        predictor=predictor,
                        elementwise=elementwise,
                        atomwise=atomwise,
                        assign_outputs=True,
                        bins=100,
                    )

                    # record some stats
                    records = update_record(crystal_batch, outputs, params_record, records, s_ind)

                    # loss and backprop
                    if monte_carlo:
                        assert False, "Update arguments"
                        parse_mc_step(T, cluster_batch, compression_factor, crystal_batch, current_energy,
                                      current_state, do_box_restriction, enforce_reduced, optim_target, outputs,
                                      param_module, records, score_model, target_packing_coeff, target_rdf,
                                      target_latent, cutoff)
                    else:
                        loss_and_backprop(cluster_batch, crystal_batch, grad_norm_clip, optimizer, outputs,
                                          param_module, records, loss_config, aux_config, s_ind)

                    if s_ind % 10 == 0:
                        gc.collect()
                        torch.cuda.empty_cache()

                    scheduler1.step()  # shrink
                    s_ind += 1
                    if s_ind % 10 == 0:
                        pbar.update(10)

                    if s_ind >= min(max_num_steps, max(50, min_num_steps)):
                        converged = check_convergence(params_record, s_ind, convergence_eps, optimizer, init_lr)

    except (RuntimeError, ValueError) as e:
        if is_cuda_oom(e):
            if s_ind > 0:
                # salvage the best parameters found so far before propagating the OOM
                records = {k: torch.stack(v) for k, v in records.items()}
                best_sample_ind = torch.argmin(records['loss'],
                                               dim=0).flatten()  # pick the best sample from each trajectory
                best_samples = torch.stack(
                    [params_record[best_sample_ind[ind], ind] for ind in range(num_samples)])
                torch.save(best_samples, 'opt_intermediates.pt')
            raise e  # we want to raise this to fire the oom catcher above
        else:
            raise e

    records = {k: torch.stack(v) for k, v in records.items()}

    # Pull out and re-analyze the best samples from each trajectory
    best_sample_ind = torch.argmin(records['loss'], dim=0).flatten()  # pick the best sample from each trajectory
    best_samples = torch.stack([params_record[best_sample_ind[ind], ind] for ind in range(num_samples)])
    crystal_batch = init_crystal_batch.clone().detach()  # this is necessary to not retain lots of intermediate tensors
    crystal_batch.set_cell_parameters(best_samples.to(crystal_batch.device), skip_box_analysis=False)
    _ = crystal_batch.analyze(
        computes=energy_computes,
        cutoff=cutoff,
        rdf_cutoff=cutoff,
        supercell_size=supercell_size,
        return_cluster=False,
        repulsion=repulsion,
        surface_padding=0,
        predictor=predictor,
        assign_outputs=True,
        elementwise=False,
        atomwise=True,
        bins=100
    )
    samples_list = crystal_batch.batch_to_list()

    if enforce_reduced:
        penalty = crystal_batch.compute_cell_reduction_penalty()
        samples_list = [elem for i, elem in enumerate(samples_list) if penalty[i] < 1e-3]

    """
    # analyze trajectory information, if we want
    from mxtaltools.common.utils import log_rescale_positive
    timesteps = torch.arange(s_ind).repeat(init_crystal_batch.num_graphs, 1).T
    traj_fig(timesteps, log_rescale_positive(records['loss']), names=['time', 'loss'])
    traj_fig(timesteps, (records['cp']), names=['time', 'cp'])
    traj_fig(records['cp'], log_rescale_positive(records['lj']), names=['cp', 'lj'])
    traj_fig(timesteps, torch.log(records['loss']), names=['time', 'loss'])
    crystal_batch.plot_batch_cell_params()
    crystal_batch.plot_batch_density_funnel()
    import plotly.graph_objects as go
    penalty = crystal_batch.compute_cell_reduction_penalty()
    go.Figure(go.Histogram(x=(penalty+1e-6).log10().cpu().detach().numpy())).show()
    """
    if optim_target == 'rdf_dist':
        timesteps = torch.arange(s_ind).repeat(init_crystal_batch.num_graphs, 1).T
        traj_fig(timesteps, torch.log(records['loss']), names=['time', 'loss'])
        if torch.amin(records['loss'][-1]).log() < -2.5:
            print("Found the crystal!")
            good_ind = torch.argmin(records['loss'][-1]).item()
            sample = crystal_batch.batch_to_list()[good_ind]
            torch.save(sample.cpu().detach(), f'{crystal_batch.identifier[0]}_standardized_match.pt')
            # NOTE(review): deliberately aborts the run once a match is saved
            assert False, "Found what we were looking for"

    # timesteps = torch.arange(s_ind).repeat(init_crystal_batch.num_graphs, 1).T
    # traj_fig(timesteps, torch.log(records['loss']), names=['time', 'loss'])
    records['params'] = params_record[:s_ind]

    return samples_list, records
""" from mxtaltools.common.ase_interface import ase_mol_from_crystaldata mol = ase_mol_from_crystaldata(cluster_batch, torch.argmin(records['loss'][-1]).item(), mode='unit cell') from ase.io import write write('aaa.cif',mol) """ """ # params trajectories from plotly.subplots import make_subplots fig = make_subplots(rows=4, cols=3, subplot_titles = ['a','b','c','al','be','ga','x','y','z','u','v','w']) for ind in range(12): row = ind // 3 + 1 col = ind % 3 + 1 for ind2 in range(params_record.shape[1]): fig.add_scatter(y=params_record[:s_ind, ind2, ind], showlegend=False, row=row, col=col) fig.show() # traj vis t_ind = torch.argmin(records['loss'][-1]).item() tbatch = collate_data_list([samples_list[0].clone().cpu() for _ in range(s_ind)]) tbatch.set_cell_parameters(params_record[:s_ind, t_ind, :]) tbatch.pose_aunit() tbatch.build_unit_cell() tbatch.visualize(mode='unit cell') """
def update_record(crystal_batch, outputs, params_record, records, s_ind):
    """
    Store the current step's cell parameters and analysis outputs.

    Writes the batch's full cell parameters into row `s_ind` of
    `params_record`, creates the record dictionary on first use (one list per
    output key except 'rdf', plus 'cp' and 'loss'), then appends a detached
    CPU copy of every tensor-valued output and of the packing coefficient.
    Returns the (possibly newly created) record dictionary.
    """
    # must put this before the .backward()
    params_record[s_ind] = crystal_batch.full_cell_parameters().detach().cpu()

    if records is None:
        records = {}
        for name in outputs:
            if name != 'rdf':
                records[name] = []
        records['cp'] = []
        records['loss'] = []

    tensor_outputs = [(name, entry) for name, entry in outputs.items() if torch.is_tensor(entry)]
    for name, entry in tensor_outputs:
        records[name].append(entry.detach().cpu())

    packing = crystal_batch.packing_coeff.detach().cpu()
    records['cp'].append(packing)
    return records
def loss_and_backprop(cluster_batch, crystal_batch, grad_norm_clip, optimizer, outputs, param_module, records,
                      loss_config, aux_config, opt_step):
    """
    Evaluate the per-sample loss, log it, and take one optimizer step.

    The primary loss from `compute_loss` is extended by
    `compute_auxiliary_loss`, a detached CPU copy is appended to
    `records['loss']`, and the batch mean is back-propagated. Gradients of
    `param_module` are norm-clipped to `grad_norm_clip` before the step.
    """
    primary = compute_loss(cluster_batch, crystal_batch, outputs, loss_config, opt_step)
    loss = compute_auxiliary_loss(cluster_batch, primary, outputs, aux_config)
    records['loss'].append(loss.detach().cpu())

    # a single scalar backward pass covers every sample
    loss_to_backprop = loss.mean()
    loss_to_backprop.backward()

    # gradient clipping, then apply the update
    torch.nn.utils.clip_grad_norm_(param_module.parameters(), grad_norm_clip)
    optimizer.step()

    # drop local references to the graph-holding objects
    del loss_to_backprop, primary, loss, cluster_batch, crystal_batch, outputs
def get_mc_noised_samples(crystal_batch, log_noise_range):
    """
    Propose perturbed latent parameters for a Monte Carlo move.

    Each sample is displaced along a random unit direction by a step whose
    size is drawn log-uniformly between 10**log_noise_range[0] and
    10**log_noise_range[1]; the result is clipped to [-1, 1].
    """
    latents = crystal_batch.latent_params()

    # random unit direction per sample
    direction = torch.randn_like(latents)
    direction = direction / direction.norm(dim=-1, keepdim=True)

    # log-uniform step size per sample
    log_low = log_noise_range[0]
    log_high = log_noise_range[1]
    uniform = torch.rand(len(latents), device=latents.device)
    step_size = 10 ** (log_low + (log_high - log_low) * uniform)

    perturbed = latents + direction * step_size[:, None]
    return perturbed.clip(min=-1, max=1)
def parse_mc_step(T, cluster_batch, compression_factor, crystal_batch, current_energy, current_state,
                  do_box_restriction, enforce_reduced, optim_target, outputs, param_module, records, score_model,
                  target_packing_coeff, target_rdf, target_latent, cutoff):
    """
    Metropolis accept/reject step for a batch of Monte Carlo proposals.

    The proposal energy is the same loss used by the gradient path
    (`compute_loss` followed by `compute_auxiliary_loss`). Accepted proposals
    overwrite `current_energy` and `current_state` in place, `param_module`
    is synced to the resulting state, and the current energy is appended to
    `records['loss']`.

    The loss helpers take config namespaces, so the individual arguments are
    packed here. This function's signature carries no RDF warm-up or umbrella
    settings; those are passed as disabled (`rdf_warmup=None`,
    `umbrella=False`), and the step index given to `compute_loss` is 0, which
    is only read when a warm-up is active.

    NOTE(review): the only visible caller currently asserts before reaching
    this function; confirm the disabled defaults above when re-enabling it.
    """
    loss_config = Namespace(
        cutoff=cutoff,
        target_rdf=target_rdf,
        score_model=score_model,
        optim_target=optim_target,
        target_latent=target_latent,
        rdf_warmup=None,
    )
    aux_config = Namespace(
        compression_factor=compression_factor,
        do_box_restriction=do_box_restriction,
        target_packing_coeff=target_packing_coeff,
        enforce_reduced=enforce_reduced,
        umbrella=False,
        umbrella_sigma=None,
        umbrella_epsilon=None,
        umbrella_record=[],
    )

    # 1. Compute energy of the proposal (noised_samples)
    proposal_energy = compute_loss(cluster_batch, crystal_batch, outputs, loss_config, 0)
    proposal_energy = compute_auxiliary_loss(cluster_batch, proposal_energy, outputs, aux_config)

    # 2. Calculate Metropolis Criterion
    # For minimization: P(accept) = exp(-(E_prop - E_curr) / T)
    delta_e = proposal_energy - current_energy
    acceptance_prob = torch.exp(-delta_e / T).clamp(max=1.0)

    # 3. Batch-wise Acceptance
    accepted = torch.rand_like(acceptance_prob) < acceptance_prob

    # 4. Update the "Current" State
    current_energy[accepted] = proposal_energy[accepted]
    current_state[accepted] = crystal_batch.full_cell_parameters()[accepted]

    # 5. Sync the param_module so the next iteration starts from the accepted state
    param_module.update_params(current_state)

    # Record the energy for your stats
    records['loss'].append(current_energy.detach().cpu())

    # 6. Cooling (Simulated Annealing)
    # T *= cooling_rate
    del proposal_energy, cluster_batch, outputs, crystal_batch
def init_mc_state(compression_factor, cutoff, do_box_restriction, energy_computes, enforce_reduced,
                  init_crystal_batch, mc_kT, optim_target, param_module, score_model, std_margin, supercell_size,
                  target_packing_coeff, uma_predictor, target_rdf, target_latent):
    """
    Evaluate the starting energy and state for a Monte Carlo run.

    A clone of `init_crystal_batch` is given the current parameters of
    `param_module`, analyzed, and scored with `compute_loss` followed by
    `compute_auxiliary_loss`.

    The loss helpers take config namespaces, so the individual arguments are
    packed here. This function's signature carries no RDF warm-up or umbrella
    settings; those are passed as disabled (`rdf_warmup=None`,
    `umbrella=False`), and the step index given to `compute_loss` is 0, which
    is only read when a warm-up is active.

    NOTE(review): the only visible caller currently asserts before reaching
    this function; confirm the disabled defaults above when re-enabling it.

    Returns
    -------
    T : the temperature, taken directly from `mc_kT`
    current_energy : per-sample energy of the initial state
    current_state : clone of the stacked parameters
    """
    # Initialize param_module and get initial energy
    crystal_batch = init_crystal_batch.clone().detach()
    crystal_batch.set_cell_parameters(param_module.stacked())
    outputs, cluster_batch = crystal_batch.analyze(
        computes=energy_computes,
        cutoff=cutoff,
        rdf_cutoff=cutoff,
        supercell_size=supercell_size,
        return_cluster=True,
        repulsion=1,
        surface_padding=0,
        std_orientation=True,
        margin=std_margin,
        predictor=uma_predictor,
    )

    loss_config = Namespace(
        cutoff=cutoff,
        target_rdf=target_rdf,
        score_model=score_model,
        optim_target=optim_target,
        target_latent=target_latent,
        rdf_warmup=None,
    )
    aux_config = Namespace(
        compression_factor=compression_factor,
        do_box_restriction=do_box_restriction,
        target_packing_coeff=target_packing_coeff,
        enforce_reduced=enforce_reduced,
        umbrella=False,
        umbrella_sigma=None,
        umbrella_epsilon=None,
        umbrella_record=[],
    )
    current_energy = compute_loss(cluster_batch, crystal_batch, outputs, loss_config, 0)
    current_energy = compute_auxiliary_loss(cluster_batch, current_energy, outputs, aux_config)
    current_state = param_module.stacked().clone()

    # Temperature schedule (Initial T should be ~ typical delta energy)
    T = mc_kT
    return T, current_energy, current_state
def init_opt(init_lr, optimizer_func, param_module):
    """
    Build an optimizer over the parameters of `param_module`.

    Parameters
    ----------
    init_lr : learning rate handed to the optimizer
    optimizer_func : str or optimizer class
        Either a case-insensitive optimizer name ('sgd', 'adam', 'rprop',
        'rmsprop', 'adamw', 'adadelta', 'nadam', 'adagrad', 'asgd', 'mcmc')
        or a callable used as `optimizer_func(parameters, lr=init_lr)`.
        'mcmc' maps to SGD as a placeholder optimizer.
    param_module : module providing `.parameters()`

    Raises
    ------
    AssertionError
        If a string is given that is not one of the names above.
    """
    if isinstance(optimizer_func, str):
        named_optimizers = {
            'sgd': optim.SGD,
            'adam': optim.Adam,
            'rprop': optim.Rprop,
            'rmsprop': optim.RMSprop,
            'adamw': optim.AdamW,
            'adadelta': optim.Adadelta,
            'nadam': optim.NAdam,  # the torch class is spelled NAdam
            'adagrad': optim.Adagrad,
            'asgd': optim.ASGD,
            'mcmc': optim.SGD,  # this is a dummy
        }
        optimizer_name = optimizer_func.lower()
        if optimizer_name not in named_optimizers:
            # raised explicitly so the check survives `python -O`
            raise AssertionError("Must pass an optimizer or an implemented string")
        optimizer_func = named_optimizers[optimizer_name]

    optimizer = optimizer_func(param_module.parameters(), lr=init_lr)
    return optimizer
def traj_fig(x, y, names=[None, None], yrange=None, xrange=None):
    """
    Plot one trace per column of `x` / `y` and highlight the endpoints.

    The last row of every column is marked in black ('Final State') and the
    first row in grey ('Initial State'). `names` supplies the x- and y-axis
    titles; `xrange` / `yrange` set the axis ranges. The figure is shown
    immediately and nothing is returned.
    """
    import plotly.graph_objects as go

    figure = go.Figure()
    num_runs = x.shape[1]
    for run in range(num_runs):
        figure.add_scatter(x=x[:, run], y=y[:, run], name=f"Run {run}", opacity=0.5)

    # endpoint markers: (row index, marker colour, legend label)
    endpoints = ((-1, 'black', 'Final State'), (0, 'grey', 'Initial State'))
    for row, colour, label in endpoints:
        figure.add_scatter(x=x[row, :], y=y[row, :], mode='markers', marker_size=20,
                           marker_color=colour, name=label)

    x_title, y_title = names[0], names[1]
    figure.update_layout(xaxis_title=x_title, yaxis_title=y_title,
                         xaxis_range=xrange, yaxis_range=yrange)
    figure.show()
def ema_trajectory(traj: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """
    Vectorized exponentially-weighted smoothing along time (dim=0).

    traj: [T, N, D]

    Step s is weighted by (1 - alpha) ** (T - 1 - s), so later steps weigh
    more. The running weighted sum carries a factor of `alpha` while the
    running weight total does not, so the result is `alpha` times the
    normalized weighted mean (a constant trajectory c maps to alpha * c).
    """
    num_steps = traj.size(0)
    step_index = torch.arange(num_steps, device=traj.device, dtype=traj.dtype)
    decay = (1 - alpha) ** step_index  # [T]
    weights = decay.flip(0).view(num_steps, 1, 1)  # decay weights, largest at the latest step

    # weighted cumulative sums
    weighted_sum = torch.cumsum(alpha * traj * weights, dim=0)
    weight_total = torch.cumsum(weights, dim=0)
    return weighted_sum / weight_total
def compute_loss(cluster_batch, crystal_batch, outputs, config, opt_step):
    """
    Select the per-sample optimization loss named by `config.optim_target`.

    Parameters
    ----------
    cluster_batch, crystal_batch
        Analyzed cluster / crystal batches (project types); only some targets
        read them.
    outputs : dict
        Analysis outputs; most targets simply return one of its entries.
    config : namespace
        Must provide `optim_target`. Depending on the target it also needs
        `score_model`, `target_rdf`, `rdf_warmup`, `cutoff` or
        `target_latent`.
    opt_step : int
        Current optimization step, only used for the 'rdf_dist' warm-up.

    Raises
    ------
    ValueError
        If `config.optim_target` is not a recognized target.
    """
    target = config.optim_target.lower()

    # todo obviate this with analysis keys
    # targets that are read straight out of the analysis outputs
    direct_output_keys = {
        'lj': 'lj',
        'silu': 'silu',
        'qlj': 'qlj',
        'elj': 'elj',
        'ellipsoid': 'ellipsoid',
        'reduce': 'reduction_en',
        'uma': 'uma',
        'mace': 'mace',
    }
    if target in direct_output_keys:
        return outputs[direct_output_keys[target]]

    if target == 'inter_overlaps':
        # force molecules apart by separating their centroids
        # (inter_overlap_loss only takes the cluster batch)
        loss = inter_overlap_loss(cluster_batch)

    elif target == 'classification_score':
        loss = -softmax_and_score(config.score_model(cluster_batch, force_edges_rebuild=False)[:, :2])

    elif target == 'rdf_score':
        loss = F.softplus(config.score_model(cluster_batch, force_edges_rebuild=False)[:, 2])

    elif target == 'rdf_dist':
        n_channels = config.target_rdf.shape[-2]  # 120
        if config.rdf_warmup is not None:
            # channels are switched on one after another over the warm-up period
            channel_warmup = config.rdf_warmup
            channel_onsets = torch.linspace(0, channel_warmup, n_channels)  # evenly spaced turn-on times
            channel_weights = torch.sigmoid((opt_step - channel_onsets) / (channel_warmup / n_channels * 0.5))
            # n_waves = 3
            # base_periods = torch.tensor([1.0, 1.6, 2.5]) * config.rdf_warmup
            # channel_idx = torch.arange(n_channels, dtype=torch.float32)
            #
            # modulation = torch.zeros(n_channels)
            # for i in range(n_waves):
            #     modulation += torch.sin(
            #         2 * torch.pi * opt_step / base_periods[i] + 2 * torch.pi * channel_idx / n_channels * (i + 1))
            # modulation = modulation / n_waves
            # channel_weights = 0.5 + 0.5 * modulation
        else:
            channel_weights = torch.ones(n_channels)

        # NOTE(review): the weights and bin edges are created on the default
        # device; confirm compute_rdf_distance moves them if the RDFs are on GPU
        rdf_loss = compute_rdf_distance(outputs['rdf'][0],
                                        config.target_rdf,
                                        torch.linspace(0, config.cutoff, config.target_rdf.shape[-1]),
                                        channel_weights=channel_weights)
        en_cut = -307  # set equal or higher to the target energy
        beta = 5
        lj_en = outputs['elj']
        # soft one-sided penalty on energies above en_cut
        lj_loss = F.softplus(beta * (lj_en - en_cut)) / beta
        loss = rdf_loss + lj_loss / 100

    elif target == 'latent_dist':
        loss = (config.target_latent - crystal_batch.latent_params()).norm(dim=-1)

    else:
        raise ValueError(f"Unrecognized optim_target: {config.optim_target}")

    return loss
def compute_auxiliary_loss(cluster_batch, loss, outputs, config):
    """
    Add the optional penalty terms enabled in `config` to `loss`.

    Terms, each applied only when its switch is set:
      * target_packing_coeff - squared deviation of the packing coefficient
      * do_box_restriction   - 80000 / length**12 summed over the scaled box lengths
      * compression_factor   - summed scaled box lengths, gated by a sigmoid
                               of the packing coefficient around 0.65
      * enforce_reduced      - 10000 * relu(outputs['reduction_en'])
      * umbrella             - Gaussian repulsion from the recorded latents,
                               capped at 10 and scaled by umbrella_epsilon
    Returns the combined loss.
    """
    if config.target_packing_coeff is not None:
        packing_error = cluster_batch.packing_coeff - config.target_packing_coeff
        loss = loss + packing_error ** 2

    if config.do_box_restriction:
        # enforce box shape cannot become too long (3 mol radii) or narrow (3 angstroms) in any direction
        # repulsive from about range 3
        # (80000/aunit_lengths**12 + 10*aunit_lengths - 31.25).sum(dim=1)  # forces boxes to be larger than 3 angstroms, but squeezes them otherwise
        box_lengths = cluster_batch.scale_lengths_to_aunit()
        loss = loss + (80000 / box_lengths ** 12).sum(dim=1)

    if config.compression_factor != 0:
        box_lengths = cluster_batch.scale_lengths_to_aunit()
        threshold, width = 0.65, 0.01  # width controls sharpness of transition
        gate = torch.sigmoid((threshold - cluster_batch.packing_coeff) / width)
        loss = loss + box_lengths.sum(dim=1) * config.compression_factor * gate

    if config.enforce_reduced:
        # severely punish reduction violations
        loss = loss + 10000 * F.relu(outputs['reduction_en'])

    if config.umbrella and len(config.umbrella_record) != 0:
        latents = cluster_batch.latent_params()
        visited = config.umbrella_record.to(cluster_batch.device)
        distances = torch.cdist(latents, visited)
        gaussian = torch.exp(-distances ** 2 / (2 * config.umbrella_sigma ** 2))
        loss = loss + config.umbrella_epsilon * gaussian.sum(dim=1).clip(max=10)

    return loss
def inter_overlap_loss(cluster_batch):
    """
    Centroid-based repulsion between the molecules of each cluster.

    Molecule centroids are computed as the mean atom position per molecule,
    centroid pairs within twice the largest molecular radius are connected,
    and every pair closer than two radii (distance scaled by the radius of
    the owning cluster) contributes `2 - scaled_distance`. Contributions are
    summed per cluster.

    Parameters
    ----------
    cluster_batch
        Batched molecular clusters (project type). Reads `z_prime`, `batch`,
        `num_atoms`, `num_graphs`, `device`, `mol_ind`, `ptr`, `pos`, `radius`.

    Returns
    -------
    loss : tensor with one entry per cluster (`cluster_batch.num_graphs`)

    Raises
    ------
    AssertionError
        If the largest `z_prime` in the batch is not 1.
    """
    # enforce molecules far enough away that they cannot possibly overlap
    # intermolecular centroid range repulsion
    assert cluster_batch.z_prime.amax() == 1, "molwise business not implemented for Z'>1"
    # atoms per cluster, then molecules per cluster by dividing by atoms per molecule
    _, atoms_per_cluster = torch.unique(cluster_batch.batch, return_counts=True)
    mols_per_cluster = (atoms_per_cluster / cluster_batch.num_atoms).long()
    # cluster index of every molecule
    molwise_batch = torch.arange(cluster_batch.num_graphs, device=cluster_batch.device).repeat_interleave(
        mols_per_cluster, dim=0)
    # batch-wide molecule index per atom: within-cluster molecule index plus the
    # number of molecules in all preceding clusters
    flat_mol_inds = cluster_batch.mol_ind + torch.cat(
        [torch.zeros(1, device=cluster_batch.device, dtype=torch.long),
         torch.cumsum(mols_per_cluster, dim=0)]).long()[:-1].repeat_interleave(
        torch.diff(cluster_batch.ptr), dim=0)
    # mean atom position of every molecule
    # NOTE(review): dim_size is taken from the last atom's index, which assumes
    # atoms are ordered so that the last atom belongs to the last molecule - confirm
    mol_centroids = scatter(cluster_batch.pos, flat_mol_inds, reduce='mean', dim=0,
                            dim_size=flat_mol_inds[-1] + 1)
    # centroid pairs within twice the largest radius, restricted to the same cluster
    edge_i, edge_j = gnn.radius_graph(
        x=mol_centroids,
        batch=molwise_batch,
        max_num_neighbors=10,
        r=float(cluster_batch.radius.amax() * 2))
    inter_dists = torch.linalg.norm(
        mol_centroids[edge_i] - mol_centroids[edge_j], dim=1
    )
    # centroid distance in units of the molecular radius of the owning cluster
    scaled_inter_dists = inter_dists / cluster_batch.radius[molwise_batch[edge_i]]
    edgewise_losses = F.relu(-(scaled_inter_dists - 2))  # push molecules apart if they are within each others' radii
    # sum the pairwise penalties per cluster
    loss = scatter(edgewise_losses, molwise_batch[edge_i], dim_size=cluster_batch.num_graphs, dim=0,
                   reduce='sum')
    return loss
""" # viz sample distribution import plotly.graph_objects as go ljs = torch.tensor([elem.scaled_lj for elem in opt_samples[:len(samples_record[0])]]) ljs2 = torch.tensor([elem.scaled_lj for elem in opt_samples]) fig = go.Figure() fig.add_histogram(x=ljs, nbinsx=100, histnorm='probability density') fig.add_histogram(x=ljs2, nbinsx=100, histnorm='probability density') fig.show() """ """ import plotly.graph_objects as go fig = go.Figure() lj_pots = torch.stack( [torch.tensor([sample.scaled_lj for sample in sample_list]) for sample_list in samples_record]) es_pots = torch.stack([torch.tensor([sample.es_pot for sample in sample_list]) for sample_list in samples_record]) orig_batch = collate_data_list(opt_samples) lj_pots = torch.cat([orig_batch.scaled_lj[None, :], lj_pots], dim=0) es_pots = torch.cat([orig_batch.es_pot[None, :], es_pots], dim=0) for ind in range(lj_pot.shape[-1]): fig.add_scatter(y=lj_pots[..., ind], marker_color='blue', name='lj', legendgroup='lg', showlegend=True if ind == 0 else False) fig.add_scatter(y=es_pots[..., ind] * 10, marker_color='red', name='es', legendgroup='es', showlegend=True if ind == 0 else False) fig.add_scatter(y=es_pots[..., ind] * 10 + lj_pots[..., ind], marker_color='green', name='combo', legendgroup='combo', showlegend=True if ind == 0 else False) fig.add_scatter(y=(lj_pots).mean(1), marker_color='black', name='lj', legendgroup='lg', showlegend=True if ind == 0 else False) fig.add_scatter(y=(es_pots * 10).mean(1), marker_color='black', name='es', legendgroup='es', showlegend=True if ind == 0 else False) fig.add_scatter(y=(es_pots * 10 + lj_pots).mean(1), marker_color='black', name='combo', legendgroup='combo', showlegend=True if ind == 0 else False) fig.show() """ def _init_for_local_opt(lr, max_num_steps, optimizer_func, sample, num_atoms): optimizer = optimizer_func([sample], lr=lr) max_lr_target_time = max_num_steps // 10 max_lr = lr * 100 grow_lambda = (max_lr / lr) ** (1 / max_lr_target_time) scheduler1 = 
optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.975) scheduler2 = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: grow_lambda) hit_max_lr = False num_samples = len(sample) vdw_record = torch.zeros((max_num_steps, num_samples)) samples_record = torch.zeros((max_num_steps, num_samples, 12)) raw_samples_record = torch.zeros_like(samples_record) handedness_record = torch.zeros((max_num_steps, num_samples)) loss_record = torch.zeros_like(vdw_record) lr_record = torch.zeros(max_num_steps) packing_record = torch.zeros_like(vdw_record) aunit_poses = torch.zeros((len(vdw_record), num_atoms, 3)) return (hit_max_lr, loss_record, lr_record, max_lr, optimizer, packing_record, samples_record, raw_samples_record, handedness_record, scheduler1, scheduler2, vdw_record, aunit_poses) # # def cleanup_sample(raw_sample, sg_ind_list, symmetries_dict): # # force outputs into physical ranges # # cell lengths have to be positive nonzero # cell_lengths = raw_sample[:, :3].clip(min=0.01) # # range from (0,pi) with 20% padding to prevent too-skinny cells # cell_angles = enforce_1d_bound(raw_sample[:, 3:6], x_span=torch.pi / 2 * 0.8, x_center=torch.pi / 2, # mode='hard') # # positions must be on 0-1 # mol_positions = enforce_1d_bound(raw_sample[:, 6:9], x_span=0.5, x_center=0.5, mode='hard') # # for now, just enforce vector norm # rotvec = raw_sample[:, 9:12] # norm = torch.linalg.norm(rotvec, dim=1) # new_norm = enforce_1d_bound(norm, x_span=0.999 * torch.pi, x_center=torch.pi, mode='hard') # MUST be nonzero # new_rotvec = rotvec / norm[:, None] * new_norm[:, None] # # invert_inds = torch.argwhere(new_rotvec[:, 2] < 0) # # new_rotvec[invert_inds] = -new_rotvec[invert_inds] # z direction always positive # # force cells to conform to crystal system # cell_lengths, cell_angles = enforce_crystal_system(cell_lengths, cell_angles, sg_ind_list, # symmetries_dict) # sample = torch.cat((cell_lengths, cell_angles, mol_positions, new_rotvec), dim=-1) # 
return sample #
def check_convergence(params_record, s_ind, convergence_eps, optimizer, init_lr):
    """
    Flag trajectories whose smoothed parameters have stopped moving.

    The first `s_ind` recorded steps are smoothed with `ema_trajectory`; for
    each trajectory the mean absolute step-to-step change over the most
    recent (up to) 50 steps is compared against `convergence_eps`. If more
    than 95% of the trajectories are below the threshold, all of them are
    marked converged.

    Parameters
    ----------
    params_record : [steps, trajectories, params] tensor of recorded parameters
    s_ind : number of valid steps in `params_record`
    convergence_eps : threshold on the mean absolute change
    optimizer, init_lr : accepted for interface compatibility, not used

    Returns
    -------
    converged : boolean tensor with one entry per trajectory
    """
    window = 50
    smoothed = ema_trajectory(params_record[:s_ind])
    # clamp at zero: a negative start would be read as an offset from the end
    # and silently shrink the window for short trajectories
    window_start = max(s_ind - window, 0)
    diffs = smoothed[window_start:s_ind, :, :].diff(dim=0).abs().mean((0, 2))
    converged = diffs < convergence_eps
    if converged.float().mean() > 0.95:
        converged.fill_(True)

    return converged