"""In the context of the `OC20 dataset <https://opencatalystproject.org/index.html>`_,
rewire each 3D molecular graph according to one of three strategies: remove all tag-0
atoms, aggregate all tag-0 atoms into a single super-node, or aggregate all tag-0 atoms
of a given element into a single super-node (hence, up to 3 super-nodes will be created
since OC20 catalysts can have up to 3 elements).

.. image:: https://raw.githubusercontent.com/vict0rsch/phast/main/examples/data/rewiring.png
    :alt: graph rewiring
    :width: 600px

.. code-block:: python

    from phast.graph_rewiring import remove_tag0_nodes

    data = load_oc20_data_batch()  # Yours to define
    rewired_data = remove_tag0_nodes(data)

.. warning::
    This module expects ``torch_geometric`` to be installed.
"""
from copy import deepcopy
import torch
from torch import cat, isin, tensor, where
from phast.utils import ensure_pyg_ok
from typing import Union
try:
from torch_geometric.utils import coalesce, remove_self_loops, sort_edge_index
from torch_geometric.data import Batch, Data
except ImportError:
pass
@ensure_pyg_ok
def remove_tag0_nodes(data: "Union[Batch, Data]") -> "Union[Batch, Data]":
    r"""
    Delete sub-surface (``data.tags == 0``) nodes and rewire the graph accordingly.

    Every edge with at least one tag-0 endpoint is dropped, the remaining node
    indices are compacted, and the per-graph tensors (``natoms``, ``ptr``,
    ``neighbors``) are recomputed.

    .. warning::
        This function modifies the input data in-place.

    Expected ``data`` tensor attributes:

    - ``pos``: node positions
    - ``atomic_numbers``: atomic numbers
    - ``batch``: mini-batch id for each atom
    - ``tags``: atom tags
    - ``edge_index``: edge indices as a $2 \times E$ tensor
    - ``force``: force vectors per atom (optional)
    - ``pos_relaxed``: relaxed atom positions (optional)
    - ``fixed``: mask for fixed atoms (optional)
    - ``natoms``: number of atoms per graph
    - ``ptr``: cumulative sum of ``natoms``
    - ``cell_offsets``: unit cell directional offset for each edge
    - ``distances``: distance between each edge's atoms

    Args:
        data (torch_geometric.Data): the data batch to re-wire

    Returns:
        torch_geometric.Data: ``data`` itself, rewired.
    """
    device = data.edge_index.device
    # Count graphs *before* filtering: a graph left without any atom must keep
    # its slot (with 0 atoms / 0 neighbors) so that per-graph tensors stay aligned.
    # ``int(tensor.max())`` avoids the builtin ``max``, which iterates in Python.
    batch_size = int(data.batch.max()) + 1

    # Node mask of the atoms that are kept (non sub-surface, tag != 0).
    keep = data.tags != 0
    non_sub = where(keep)[0]
    # An edge is kept only if both of its endpoints are kept.
    neither_is_sub = keep[data.edge_index[0]] & keep[data.edge_index[1]]

    # per-atom tensors
    data.pos = data.pos[non_sub, :]
    data.atomic_numbers = data.atomic_numbers[non_sub]
    data.batch = data.batch[non_sub]
    data.tags = data.tags[non_sub]
    if hasattr(data, "force"):
        data.force = data.force[non_sub, :]
    if hasattr(data, "fixed"):
        data.fixed = data.fixed[non_sub]
    if hasattr(data, "pos_relaxed"):
        data.pos_relaxed = data.pos_relaxed[non_sub, :]

    # per-edge tensors
    data.cell_offsets = data.cell_offsets[neither_is_sub, :]
    data.distances = data.distances[neither_is_sub]
    # Old node id -> new (compacted) node id. Deleted nodes map to -1, but no
    # remaining edge references them.
    assoc = torch.full((keep.numel(),), -1, dtype=torch.long, device=device)
    assoc[keep] = torch.arange(non_sub.numel(), device=device)
    data.edge_index = assoc[data.edge_index[:, neither_is_sub]]

    # per-graph tensors
    natoms = torch.bincount(data.batch, minlength=batch_size)
    data.natoms = natoms.to(dtype=data.natoms.dtype)
    data.ptr = cat([natoms.new_zeros(1), natoms.cumsum(dim=0)]).to(
        dtype=data.ptr.dtype
    )
    # ``minlength`` keeps one entry per graph even when a graph has no edge left
    # (``torch.unique`` would silently skip it and misalign the counts).
    data.neighbors = torch.bincount(
        data.batch[data.edge_index[0]], minlength=batch_size
    )
    return data
@ensure_pyg_ok
def one_supernode_per_graph(
    data: "Union[Batch, Data]", cutoff: float = 6.0, num_elements: int = 83
) -> "Union[Batch, Data]":
    r"""
    Replace all tag-0 atoms with a single super-node $S$ representing them, per graph.

    For each graph, $S$ is the last node in the graph.
    $S$ is positioned at the center of mass of all tag-0 atoms in $x$ and $y$ directions
    but at the maximum $z$ coordinate of all tag-0 atoms.
    All atoms previously connected to a tag-0 atom are now connected to $S$ unless
    that would create an edge longer than ``cutoff``.

    Expected ``data`` attributes are the same as for :func:`remove_tag0_nodes`.

    .. note::
        $S$ will be created with a new atomic number $Z_{S} = num\_elements + 1$,
        so this should be set to the number of elements expected to be present in the
        dataset, not that of the current graph.

    .. warning::
        This function modifies the input data in-place.

    .. warning::
        Every graph must contain at least one tag-0 atom: the super-node position
        takes the ``max`` over them, which raises on an empty selection.

    Args:
        data (torch_geometric.Data): single batch of graphs
        cutoff (float, optional): super-node edges longer than this are removed.
            Defaults to 6.0.
        num_elements (int, optional): number of chemical elements in the dataset;
            super-nodes get atomic number ``num_elements + 1``. Defaults to 83.

    Returns:
        torch_geometric.Data: the rewired ``data``.
    """
    # ``int(tensor.max())`` avoids the builtin ``max``, which iterates in Python.
    batch_size = int(data.batch.max()) + 1
    device = data.edge_index.device
    # Number of atoms before rewiring (``data.ptr`` is overwritten below).
    num_nodes = int(data.ptr[-1])

    # ids of sub-surface (tag-0) nodes, per graph
    sub_nodes = [
        where((data.tags == 0) & (data.batch == i))[0] for i in range(batch_size)
    ]
    # ids of non-sub-surface (tag != 0) nodes, per graph
    non_sub_nodes = [
        where((data.tags != 0) & (data.batch == i))[0] for i in range(batch_size)
    ]
    # New index of each graph's super-node: it is the last node of its graph,
    # i.e. it follows the kept atoms of graphs 0..i and the i previous super-nodes.
    new_sn_ids = []
    num_kept = 0
    for i, nsn in enumerate(non_sub_nodes):
        num_kept += len(nsn)
        new_sn_ids.append(num_kept + i)

    # new number of atoms per graph: kept atoms + 1 super-node
    data.ptr = tensor(
        [0] + [nsi + 1 for nsi in new_sn_ids], dtype=data.ptr.dtype, device=device
    )
    data.natoms = data.ptr[1:] - data.ptr[:-1]
    # number of tag-0 atoms aggregated by each super-node
    data.subnodes = tensor(
        [len(sub) for sub in sub_nodes], dtype=torch.long, device=device
    )

    # super-node position: xy mean and z max of the atoms it aggregates
    sn_pos = [
        cat([data.pos[sub, :2].mean(0), data.pos[sub, 2].max().unsqueeze(0)], dim=0)
        for sub in sub_nodes
    ]

    # the super-node force is the mean of the forces applied to its aggregates
    if hasattr(data, "force"):
        sn_force = [data.force[sub].mean(0) for sub in sub_nodes]
        data.force = cat(
            [
                cat([data.force[non_sub_nodes[i]], sn_force[i][None, :]])
                for i in range(batch_size)
            ]
        )

    # super-nodes get a dedicated atomic number (hence their own embedding);
    # the dtype follows ``atomic_numbers`` so that ``cat`` never mixes dtypes.
    data.atomic_numbers = cat(
        [
            cat(
                [
                    data.atomic_numbers[non_sub_nodes[i]],
                    tensor(
                        [num_elements + 1],
                        dtype=data.atomic_numbers.dtype,
                        device=device,
                    ),
                ]
            )
            for i in range(batch_size)
        ]
    )

    # positions exclude sub-surface atoms but include the extra super-nodes
    data.pos = cat(
        [
            cat([data.pos[non_sub_nodes[i]], sn_pos[i][None, :]])
            for i in range(batch_size)
        ]
    )

    # the relaxed position of a super-node is its initial position
    if hasattr(data, "pos_relaxed"):
        data.pos_relaxed = cat(
            [
                cat([data.pos_relaxed[non_sub_nodes[i]], sn_pos[i][None, :]])
                for i in range(batch_size)
            ]
        )

    # super-nodes are fixed
    if hasattr(data, "fixed"):
        data.fixed = cat(
            [
                cat(
                    [
                        data.fixed[non_sub_nodes[i]],
                        tensor([1.0], dtype=data.fixed.dtype, device=device),
                    ]
                )
                for i in range(batch_size)
            ]
        )

    # super-nodes have tag 0
    data.tags = cat(
        [
            cat(
                [
                    data.tags[non_sub_nodes[i]],
                    tensor([0], dtype=data.tags.dtype, device=device),
                ]
            )
            for i in range(batch_size)
        ]
    )

    # ---- edges ----
    # Graph id of each edge, read from its source node in the *original* batch
    # vector (``data.batch`` is only rebuilt at the very end).
    batch_idx_adj = data.batch[data.edge_index[0]]

    # Old node id -> new node id (original atoms only, no super-nodes yet).
    # Kept atoms of graph e take ids ptr[e] .. ptr[e + 1] - 2 (ptr[e + 1] - 1 is
    # the super-node); tag-0 atoms map to -1 for now.
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    mask[cat(non_sub_nodes)] = True
    assoc = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    assoc[mask] = cat(
        [
            torch.arange(data.ptr[e], data.ptr[e + 1] - 1, device=device)
            for e in range(batch_size)
        ]
    )
    # Re-index every edge; tag-0 endpoints become -1.
    ei_sn = assoc[data.edge_index]

    # Edges touching a tag-0 atom are redirected to their graph's super-node and
    # get a [0, 0, 0] cell offset (scalar assignment works for any dtype).
    is_sub = ei_sn == -1
    new_cell_offsets = data.cell_offsets.clone()
    new_cell_offsets[is_sub.any(dim=0)] = 0
    sn_ids = tensor(new_sn_ids, device=device)
    ei_sn = where(is_sub, sn_ids[batch_idx_adj], ei_sn)

    # Former tag0-tag0 edges are now super-node self loops: drop them.
    ei_sn, new_cell_offsets = remove_self_loops(ei_sn, new_cell_offsets)

    # Redirecting several tag-0 neighbours to the same super-node creates
    # duplicate edges: deduplicate super-node edges only and keep the tag-1/2
    # adjacency untouched.
    new_non_sub_nodes = where(data.tags != 0)[0]
    is_tag12_edge = isin(ei_sn, new_non_sub_nodes).all(dim=0)
    tag12_ei = ei_sn[:, is_tag12_edge]
    tag12_cell_offsets_ei = new_cell_offsets[is_tag12_edge, :]
    is_sn_edge = isin(ei_sn, sn_ids).any(dim=0)
    ei_sn, new_cell_offsets = coalesce(
        ei_sn[:, is_sn_edge], edge_attr=new_cell_offsets[is_sn_edge, :], reduce="min"
    )
    # Merge both edge sets back together
    ei_sn = cat([tag12_ei, ei_sn], dim=1)
    new_cell_offsets = cat([tag12_cell_offsets_ei, new_cell_offsets], dim=0)
    ei_sn, new_cell_offsets = sort_edge_index(ei_sn, edge_attr=new_cell_offsets)

    # ensure correct dtypes
    data.edge_index = ei_sn.to(dtype=data.edge_index.dtype)
    data.cell_offsets = new_cell_offsets.to(dtype=data.cell_offsets.dtype)

    # NOTE(review): distances are plain Euclidean distances between stored
    # positions; ``cell_offsets`` are not applied, so periodic tag-1/2 edges may
    # get a different length than before rewiring — confirm this is intended.
    data.distances = torch.sqrt(
        ((data.pos[data.edge_index[0]] - data.pos[data.edge_index[1]]) ** 2).sum(-1)
    ).to(dtype=data.distances.dtype)

    # batch: graph i owns natoms[i] consecutive nodes
    data.batch = torch.repeat_interleave(
        torch.arange(batch_size, dtype=data.batch.dtype, device=device), data.natoms
    )
    return adjust_cutoff_distances(data, new_sn_ids, cutoff)
@ensure_pyg_ok
def one_supernode_per_atom_type(
    data: "Union[Batch, Data]", cutoff: float = 6.0
) -> "Union[Batch, Data]":
    r"""
    For each graph independently, replace all tag-0 atoms of a given element by a new
    super node $S_i, \ i \in \{1..3\}$. As per :func:`one_supernode_per_graph`, each
    $S_i$ is positioned at the center of mass of the atoms it replaces in $x$ and $y$
    dimensions but at the maximum height of the atoms it replaces in the $z$ dimension.

    The super nodes of a graph are appended after its remaining (tag != 0) atoms, in
    increasing atomic-number order. A graph without tag-0 atoms simply gets no
    super node.

    Expected ``data`` attributes are the same as for :func:`remove_tag0_nodes`.

    .. note::
        $S_i$ conserves the atomic number of the tag-0 atoms it replaces.

    .. warning::
        This function modifies the input data in-place.

    Args:
        data (torch_geometric.Data): the data batch to re-wire
        cutoff (float, optional): super-node edges longer than this are removed.
            Defaults to 6.0.

    Returns:
        torch_geometric.Data: the rewired data batch
    """
    # ``int(tensor.max())`` avoids the builtin ``max``, which iterates in Python.
    batch_size = int(data.batch.max()) + 1
    device = data.edge_index.device
    # Number of atoms before rewiring (``data.ptr`` is overwritten below).
    num_nodes = int(data.ptr[-1])

    # ids of non-sub-surface (tag != 0) nodes, per graph
    non_sub_nodes = [
        where((data.tags != 0) & (data.batch == i))[0] for i in range(batch_size)
    ]
    # distinct elements among the tag-0 atoms of each graph (sorted by torch.unique)
    atom_types = [
        torch.unique(data.atomic_numbers[(data.tags == 0) & (data.batch == i)])
        for i in range(batch_size)
    ]
    # number of super nodes per graph, and its running total:
    # acc_num_supernodes[i] = number of super nodes in graphs 0..i-1
    num_supernodes = [at.shape[0] for at in atom_types]
    acc_num_supernodes = [0]
    for n_sn in num_supernodes:
        acc_num_supernodes.append(acc_num_supernodes[-1] + n_sn)

    # indexes of the nodes aggregated by each super node (graph-major order)
    supernodes_composition = [
        where((data.atomic_numbers == an) & (data.tags == 0) & (data.batch == i))[0]
        for i in range(batch_size)
        for an in atom_types[i]
    ]
    # number of nodes each super node regroups
    data.subnodes = tensor(
        [len(sub) for sub in supernodes_composition], dtype=torch.long, device=device
    )

    # New super-node indices (they are last in their graph) and new ptr.
    # Computing ptr from running counts (instead of ``max(ids) + 1``) also works
    # for graphs that have no super node.
    new_sn_ids = []
    new_ptr = [0]
    num_kept = 0
    for i, nsn in enumerate(non_sub_nodes):
        num_kept += len(nsn)
        first = num_kept + acc_num_supernodes[i]
        end = num_kept + acc_num_supernodes[i + 1]
        new_sn_ids.append(list(range(first, end)))
        new_ptr.append(end)
    new_sn_ids_cat = [s for sn in new_sn_ids for s in sn]

    # Super-node positions stacked in one (num_supernodes, 3) tensor: xy mean and
    # z max of the aggregated atoms. Slicing it per graph yields an empty (0, 3)
    # block for graphs without super nodes instead of ``cat([])`` failing.
    sn_pos = [
        cat([data.pos[sn, :2].mean(0), data.pos[sn, 2].max().unsqueeze(0)], dim=0)[
            None, :
        ]
        for sn in supernodes_composition
    ]
    sn_pos = cat(sn_pos) if sn_pos else data.pos.new_zeros((0, 3))

    # number of atoms per graph in the batch
    data.ptr = tensor(new_ptr, dtype=data.ptr.dtype, device=device)
    data.natoms = data.ptr[1:] - data.ptr[:-1]

    # batch: graph i owns natoms[i] consecutive nodes
    data.batch = torch.repeat_interleave(
        torch.arange(batch_size, dtype=data.batch.dtype, device=device), data.natoms
    )

    # tags: super nodes have tag 0
    data.tags = cat(
        [
            cat(
                [
                    data.tags[non_sub_nodes[i]],
                    tensor([0], dtype=data.tags.dtype, device=device).expand(
                        num_supernodes[i]
                    ),
                ]
            )
            for i in range(batch_size)
        ]
    )

    # ---- edges ----
    # Old node id -> new node id: kept atoms first, then each tag-0 atom points
    # to the super node of its element.
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    mask[cat(non_sub_nodes)] = True
    assoc = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    assoc[mask] = cat(
        [
            torch.arange(
                data.ptr[e], data.ptr[e + 1] - num_supernodes[e], device=device
            )
            for e in range(batch_size)
        ]
    )
    for sn_id, sn in zip(new_sn_ids_cat, supernodes_composition):
        assoc[sn] = sn_id
    edge_index = assoc[data.edge_index]

    # Super-node edges get a [0, 0, 0] cell offset. Work on a copy so the caller's
    # original ``cell_offsets`` tensor is not mutated.
    sn_ids = tensor(new_sn_ids_cat, dtype=torch.long, device=device)
    cell_offsets = data.cell_offsets.clone()
    cell_offsets[isin(edge_index, sn_ids).any(dim=0)] = 0

    # Former same-element tag0-tag0 edges are now self loops: drop them.
    edge_index, cell_offsets = remove_self_loops(edge_index, cell_offsets)

    # Deduplicate super-node edges only; keep tag-1/2 adjacency untouched.
    new_non_sub_nodes = where(data.tags != 0)[0]
    is_tag12_edge = isin(edge_index, new_non_sub_nodes).all(dim=0)
    tag12_ei = edge_index[:, is_tag12_edge]
    tag12_cell_offsets = cell_offsets[is_tag12_edge, :]
    is_sn_edge = isin(edge_index, sn_ids).any(dim=0)
    sn_ei, sn_cell_offsets = coalesce(
        edge_index[:, is_sn_edge], edge_attr=cell_offsets[is_sn_edge, :], reduce="min"
    )
    # Merge both edge sets back together
    data.edge_index, data.cell_offsets = sort_edge_index(
        cat([tag12_ei, sn_ei], dim=1),
        edge_attr=cat([tag12_cell_offsets, sn_cell_offsets], dim=0),
    )

    # super nodes are last in their graph and keep their element's atomic number
    data.atomic_numbers = cat(
        [
            cat([data.atomic_numbers[non_sub_nodes[i]], atom_types[i]])
            for i in range(batch_size)
        ]
    )

    # positions exclude the sub-surface atoms but include the extra super nodes
    data.pos = cat(
        [
            cat(
                [
                    data.pos[non_sub_nodes[i]],
                    sn_pos[acc_num_supernodes[i] : acc_num_supernodes[i + 1]],
                ]
            )
            for i in range(batch_size)
        ]
    )

    # the relaxed position of a super node is its initial position
    if hasattr(data, "pos_relaxed"):
        data.pos_relaxed = cat(
            [
                cat(
                    [
                        data.pos_relaxed[non_sub_nodes[i]],
                        sn_pos[acc_num_supernodes[i] : acc_num_supernodes[i + 1]],
                    ]
                )
                for i in range(batch_size)
            ]
        )

    # the force applied on a super node is the mean of the forces applied to
    # its aggregates
    if hasattr(data, "force"):
        sn_force = [
            data.force[sn].mean(0)[None, :] for sn in supernodes_composition
        ]
        sn_force = (
            cat(sn_force)
            if sn_force
            else data.force.new_zeros((0, data.force.shape[1]))
        )
        data.force = cat(
            [
                cat(
                    [
                        data.force[non_sub_nodes[i]],
                        sn_force[acc_num_supernodes[i] : acc_num_supernodes[i + 1]],
                    ]
                )
                for i in range(batch_size)
            ]
        )

    # super nodes are fixed
    if hasattr(data, "fixed"):
        data.fixed = cat(
            [
                cat(
                    [
                        data.fixed[non_sub_nodes[i]],
                        tensor([1.0], dtype=data.fixed.dtype, device=device).expand(
                            num_supernodes[i]
                        ),
                    ]
                )
                for i in range(batch_size)
            ]
        )

    # distances
    # TODO: compute with cell_offsets (plain Euclidean distance for now)
    data.distances = torch.sqrt(
        ((data.pos[data.edge_index[0]] - data.pos[data.edge_index[1]]) ** 2).sum(-1)
    ).to(dtype=data.distances.dtype)
    return adjust_cutoff_distances(data, new_sn_ids_cat, cutoff)
@ensure_pyg_ok
def adjust_cutoff_distances(
    data: "Union[Data, Batch]",
    sn_indxes: "Union[torch.Tensor, list]",
    cutoff: float = 6.0,
) -> "Union[Data, Batch]":
    """
    Because of rewiring, some edges could now be longer than the allowed cutoff
    distance. This function removes them, considering only edges that have at
    least one super-node endpoint.

    Modified attributes:

    * ``edge_index``
    * ``cell_offsets``
    * ``distances``
    * ``neighbors`` (number of edges per graph, one entry per graph)

    .. warning::
        This function modifies the input data in-place.

    Args:
        data (torch_geometric.Data): The rewired graph data.
        sn_indxes (torch.Tensor[torch.Long] or list[int]): Indices of the supernodes.
        cutoff (float, optional): Maximum edge length. Defaults to 6.0.

    Returns:
        torch_geometric.Data: The updated graph.
    """
    edge_index = data.edge_index
    # ``as_tensor`` accepts lists and tensors alike (``tensor(t)`` warns on a
    # tensor input) and the explicit dtype keeps an empty list integer-typed.
    sn_ids = torch.as_tensor(
        sn_indxes, dtype=edge_index.dtype, device=edge_index.device
    )
    # remove long edges (> cutoff), for super-node related edges only
    touches_sn = isin(edge_index, sn_ids).any(dim=0)
    keep = ~((data.distances > cutoff) & touches_sn)
    data.edge_index = edge_index[:, keep]
    data.cell_offsets = data.cell_offsets[keep, :]
    data.distances = data.distances[keep]
    # One count per graph, including graphs left without edges (``torch.unique``
    # would silently skip them and misalign the per-graph counts).
    num_graphs = int(data.batch.max()) + 1 if data.batch.numel() else 0
    data.neighbors = torch.bincount(
        data.batch[data.edge_index[0]], minlength=num_graphs
    )
    return data