# Source code for mxtaltools.dataset_utils.utils

import torch.nn.functional as F
from random import shuffle
from typing import Optional

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torch_geometric.data import DataLoader, Batch
from torch_geometric.loader.dataloader import DataLoader


def collate_data_list(data_list,
                      exclude_unit_cell: bool = True,
                      max_z_prime: Optional[int] = None,
                      exclude_keys: Optional[list] = None,
                      skip_default_exclusion: bool = False):
    """
    Collate one or more data objects into a single ``Batch``.

    Parameters
    ----------
    data_list
        A list of data objects, or a single data object (which is wrapped in
        a list). If the argument has a truthy ``is_batch`` attribute it is
        returned unchanged after printing "Already batched".
    exclude_unit_cell
        When True, ``'unit_cell_pos'`` is added to the keys excluded from
        collation.
    max_z_prime
        For crystal batches, the value stored as ``batch.max_z_prime``. When
        None, ``int(batch.z_prime.amax())`` is used instead.
    exclude_keys
        Extra keys to exclude from collation, in addition to the defaults.
    skip_default_exclusion
        When True, the built-in list of excluded keys is not applied.

    Returns
    -------
    The collated ``Batch`` (or ``data_list`` itself if it was already a batch).

    Raises
    ------
    AssertionError
        If ``data_list`` is empty, or if the trimmed ``aunit_centroid`` width
        does not equal ``3 * batch.max_z_prime``.

    Notes
    -----
    Samples whose ``max_z_prime`` is below the largest one in the list are
    padded **in place** (their ``aunit_*`` attributes and ``max_z_prime`` are
    overwritten) so that all samples have matching widths before collation.
    """
    if len(data_list) == 0:
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("Data list is empty!")

    if hasattr(data_list, 'is_batch'):
        if data_list.is_batch:
            print("Already batched")
            return data_list

    if not isinstance(data_list, list):
        data_list = [data_list]

    # Optionally exclude known keys
    # todo this really needs to be fixed up
    if not skip_default_exclusion:
        exclude_keys_i = ['edges_dict',
                          'niggli_energy',
                          'core_energy',
                          'density_energy',
                          'lj_energy',
                          'bounding_energy',
                          'es_pot',
                          'gfn_energy',
                          'lj_pot',
                          'normed_lj_pot',
                          'qlj_pot',
                          'elj_pot',
                          'silu_pot',
                          ]
    else:
        exclude_keys_i = []

    if exclude_keys is not None:
        exclude_keys_i.extend(exclude_keys)

    if exclude_unit_cell:
        exclude_keys_i.append('unit_cell_pos')

    if hasattr(data_list[0], 'max_z_prime'):
        # Right-pad the last dimension of the per-sample asymmetric-unit
        # parameters with zeros so every sample matches the widest one.
        # NOTE(review): assumes centroid/orientation carry 3 entries per
        # Z' unit and handedness carries 1 - confirm against the data class.
        batch_max_z_prime = max(int(d.max_z_prime) for d in data_list)
        for d in data_list:
            d_zp = int(d.max_z_prime)
            if d_zp < batch_max_z_prime:
                pad = batch_max_z_prime - d_zp
                d.aunit_centroid = F.pad(d.aunit_centroid, (0, 3 * pad))
                d.aunit_orientation = F.pad(d.aunit_orientation, (0, 3 * pad))
                d.aunit_handedness = F.pad(d.aunit_handedness, (0, pad))
                d.max_z_prime = batch_max_z_prime

    batch = Batch.from_data_list(data_list,
                                 exclude_keys=list(exclude_keys_i),
                                 )

    # Crystal batches are recognised by their string representation.
    if 'crystal' in batch.__str__().lower():
        if max_z_prime is not None:
            batch.max_z_prime = max_z_prime
        else:
            batch.max_z_prime = int(batch.z_prime.amax())

        if hasattr(batch, 'aunit_centroid'):
            # Trim the (possibly padded) parameters to the batch-level Z'.
            batch.aunit_centroid = batch.aunit_centroid[:, :3 * batch.max_z_prime]
            batch.aunit_orientation = batch.aunit_orientation[:, :3 * batch.max_z_prime]
            batch.aunit_handedness = batch.aunit_handedness[:, :batch.max_z_prime]
            assert batch.aunit_centroid.shape[1] == 3 * batch.max_z_prime, \
                "Batch max z prime must agree with parameterization"

        if not hasattr(batch, 'symmetry_operators'):
            batch.reset_sg_info(batch.sg_ind)

    return batch
def quick_combine_dataloaders(dataset, data_loader, batch_size, max_size):
    """
    Merge a new list of samples with the samples held by an existing loader.

    The old loader's dataset is shuffled in place, appended to ``dataset``
    (which is therefore also modified in place), and the combined list is
    cut to its first ``max_size`` entries, so any overflow is dropped from
    the shuffled old samples at the tail.

    Parameters
    ----------
    dataset
        New samples; must support ``extend`` and slicing.
    data_loader
        Existing loader whose ``dataset``, ``num_workers``, ``pin_memory``
        and ``drop_last`` settings are reused.
    batch_size
        Batch size of the returned loader.
    max_size
        Maximum number of samples kept in the combined dataset.

    Returns
    -------
    A new shuffling ``DataLoader`` over the combined, truncated samples.
    """
    previous_samples = data_loader.dataset
    shuffle(previous_samples)  # in-place random reordering of the old samples
    dataset.extend(previous_samples)  # new samples first, old ones after

    return DataLoader(
        dataset[:max_size],  # overflow is cut from the end, i.e. old samples
        batch_size=batch_size,
        shuffle=True,
        num_workers=data_loader.num_workers,
        pin_memory=data_loader.pin_memory,
        drop_last=data_loader.drop_last,
    )
def quick_combine_crystal_embedding_dataloaders(dataset, data_loader, batch_size, max_size):
    """
    Merge a new ``x``/``y`` dataset with the one held by an existing loader.

    The old loader's ``x`` and ``y`` tensors are permuted with a shared
    random index, concatenated after the new dataset's tensors along
    dimension 0, and truncated to the first ``max_size`` rows.

    Parameters
    ----------
    dataset
        New dataset exposing ``x`` and ``y`` tensors.
    data_loader
        Existing loader; its dataset must expose ``x`` and ``y`` tensors, and
        its ``num_workers``, ``pin_memory`` and ``drop_last`` are reused.
    batch_size
        Batch size of the returned loader.
    max_size
        Maximum number of rows kept in the combined dataset.

    Returns
    -------
    A new shuffling ``DataLoader`` over a ``SimpleDataset`` of the merged data.
    """
    old_data = data_loader.dataset
    old_x, old_y = old_data.x, old_data.y

    # One shared permutation keeps x/y rows paired while randomizing the old data.
    num_old = len(old_x)
    permutation = torch.tensor(
        np.random.choice(num_old, num_old, replace=False),
        dtype=torch.long,
    )

    # New samples go first so truncation removes old samples only.
    merged_x = torch.cat((dataset.x, old_x[permutation]), dim=0)[:max_size]
    merged_y = torch.cat((dataset.y, old_y[permutation]), dim=0)[:max_size]

    return DataLoader(
        SimpleDataset(x=merged_x, y=merged_y),
        batch_size=batch_size,
        shuffle=True,
        num_workers=data_loader.num_workers,
        pin_memory=data_loader.pin_memory,
        drop_last=data_loader.drop_last,
    )
def filter_graph_nodewise(data, keep_index=None, delete_index=None):
    """
    Return the subgraph of ``data`` restricted to a subset of nodes.

    NOTE: this does not work because of our custom data structure.

    Parameters
    ----------
    data
        pyg data batch; must provide ``edge_index`` and ``subgraph``.
    keep_index
        Boolean mask or indexes of the nodes that should be kept. Takes
        precedence over ``delete_index`` when both are given.
    delete_index
        Indexes to drop. Only used when ``keep_index`` is None; the kept
        indexes are then every ``i`` in ``range(len(data))`` that is not in
        ``delete_index``.
        NOTE(review): this relies on ``len(data)`` being the node count -
        confirm for the data class in use.

    Returns
    -------
    The result of ``data.subgraph(keep_index)``.

    Raises
    ------
    AssertionError
        If neither ``keep_index`` nor ``delete_index`` is provided.
    """
    if keep_index is None and delete_index is None:
        # Explicit raise so the validation is not stripped under ``python -O``.
        raise AssertionError("Either keep_index or delete_index must be provided")

    if keep_index is None:
        keep_index = [ind for ind in range(len(data)) if ind not in delete_index]

    if data.edge_index is None:
        data.edge_index = torch.arange(2)  # necessary dummy

    return data.subgraph(keep_index)
def basic_stats(values: Tensor) -> dict[str, object]:
    """
    Compute summary statistics of a tensor.

    Parameters
    ----------
    values
        Tensor to summarize. Non-floating inputs are cast to float32 for the
        mean/std/quantile/histogram computations.

    Returns
    -------
    dict with keys:

    - ``'max'``, ``'min'``: ``torch.amax`` / ``torch.amin`` of ``values``.
    - ``'mean'``, ``'std'``: computed on ``values.float()``.
    - ``'tight_mean'``, ``'tight_std'``: mean/std after clipping the values
      to their 5%-95% quantile range.
    - ``'histogram'``: result of ``torch.histogram(..., bins=50)``, i.e. a
      ``(hist, bin_edges)`` pair, computed on CPU.
    - ``'uniques'``: ``torch.unique(values, return_counts=True)`` for
      ``torch.long`` inputs, otherwise ``(None, None)``.
    """
    max_quantile_elements = int(16e6)  # only the first 16e6 elements feed the quantile estimate

    float_values = values.float()

    # Flatten before truncating so the element cap also holds for
    # multi-dimensional inputs; quantile without ``dim`` flattens anyway.
    sample = values.reshape(-1)[:max_quantile_elements].float()
    quantile_levels = torch.tensor([0.05, 0.95], dtype=sample.dtype, device=sample.device)
    lower, upper = torch.quantile(sample, quantile_levels)

    # Keep floating inputs in their own dtype; cast integer inputs so the
    # clipped values support mean/std.
    clip_input = values if values.is_floating_point() else float_values
    clipped_values = clip_input.clip(min=lower, max=upper)

    if values.dtype == torch.long:
        uniques = torch.unique(values, return_counts=True)
    else:
        uniques = (None, None)

    return {
        'max': torch.amax(values),
        'min': torch.amin(values),
        'mean': torch.mean(float_values),
        'std': torch.std(float_values),
        'tight_mean': torch.mean(clipped_values),
        'tight_std': torch.std(clipped_values),
        # torch.histogram is only implemented for CPU tensors; ``.cpu()`` is
        # a no-op for inputs that already live on the CPU.
        'histogram': torch.histogram(float_values.cpu(), bins=50),
        'uniques': uniques,
    }
def get_dataloaders(dataset_builder, machine, batch_size, test_fraction=0.2, shuffle=True, num_workers: int = 0):
    """
    Split a dataset into test and train parts and wrap each in a DataLoader.

    The first ``test_size`` samples form the test set and the remaining
    samples form the training set, where
    ``train_size = int((1 - test_fraction) * len(dataset_builder))``.

    Parameters
    ----------
    dataset_builder
        Indexable collection of samples supporting ``len`` and integer
        ``__getitem__``.
    machine
        When equal to ``'cluster'`` the loaders use ``num_workers`` workers;
        for any other value they use 0 workers.
    batch_size
        Batch size for both loaders.
    test_fraction
        Fraction of the samples assigned to the test set.
    shuffle
        Passed through to both loaders.
    num_workers
        Worker count used on the cluster only.

    Returns
    -------
    ``(train_loader, test_loader)``; ``train_loader`` is None when the
    training split is empty.
    """
    num_samples = len(dataset_builder)
    train_size = int((1 - test_fraction) * num_samples)  # split data into training and test sets
    test_size = num_samples - train_size

    # Samples are fetched one by one so that only integer indexing is required.
    train_dataset = [dataset_builder[i] for i in range(test_size, test_size + train_size)]
    test_dataset = [dataset_builder[i] for i in range(test_size)]

    # faster dataloading on cluster with more workers
    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers if machine == 'cluster' else 0,
        'pin_memory': True,
        'drop_last': False,
    }

    tr = DataLoader(train_dataset, **loader_kwargs) if len(train_dataset) > 0 else None
    te = DataLoader(test_dataset, **loader_kwargs)

    return tr, te
def update_dataloader_batch_size(loader, new_batch_size):
    """
    Build a new shuffling DataLoader over the same dataset with a new batch size.

    Parameters
    ----------
    loader
        Existing loader; its ``dataset``, ``num_workers``, ``pin_memory`` and
        ``drop_last`` are carried over.
    new_batch_size
        Batch size of the returned loader.

    Returns
    -------
    A new ``DataLoader`` with ``shuffle=True``; ``loader`` is left untouched.
    """
    carried_over = {
        setting: getattr(loader, setting)
        for setting in ('num_workers', 'pin_memory', 'drop_last')
    }
    return DataLoader(
        loader.dataset,
        batch_size=new_batch_size,
        shuffle=True,
        **carried_over,
    )
class SimpleDataset(Dataset):
    """
    Minimal dataset pairing inputs ``x`` with targets ``y``.

    Both arguments are stored as float32 tensor copies; indexing returns the
    ``(x[idx], y[idx])`` pair and the length is ``len(x)``.
    """

    def __init__(self, x, y):
        """
        Parameters
        ----------
        x
            Inputs: a tensor or anything ``torch.tensor`` accepts.
        y
            Targets: a tensor or anything ``torch.tensor`` accepts.
        """
        self.x = self._as_float_copy(x)
        self.y = self._as_float_copy(y)

    @staticmethod
    def _as_float_copy(values):
        """Return a detached float32 copy of ``values``."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
            # detach + copying ``to`` yields the same detached copy silently.
            return values.detach().to(dtype=torch.float32, copy=True)
        return torch.tensor(values, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]