phast.graph_rewiring

In the context of the OC20 dataset, rewire each 3D molecular graph according to one of three strategies: remove all tag-0 (sub-surface) atoms, aggregate all tag-0 atoms into a single super-node, or aggregate the tag-0 atoms of each element into their own super-node (hence, up to 3 super-nodes are created per graph, since OC20 catalysts contain at most 3 elements).

graph rewiring

```python
from phast.graph_rewiring import remove_tag0_nodes

data = load_oc20_data_batch()  # Yours to define
rewired_data = remove_tag0_nodes(data)  # note: `data` itself is modified in place
```

Warning

This module expects torch_geometric to be installed.

Functions

adjust_cutoff_distances(data, sn_indxes[, cutoff])

Because of rewiring, some edges could now be longer than the allowed cutoff distance; remove them.

one_supernode_per_atom_type(data[, cutoff])

For each graph independently, replace all tag-0 atoms of a given element by a new super-node.

one_supernode_per_graph(data[, cutoff, num_elements])

Replace all tag-0 atoms with a single super-node \(S\) representing them, per graph.

remove_tag0_nodes(data)

Delete sub-surface (data.tag == 0) nodes and rewire the graph accordingly.

phast.graph_rewiring.adjust_cutoff_distances(data, sn_indxes, cutoff=6.0)

Because of rewiring, some edges could now be longer than the allowed cutoff distance. This function removes them.

Modified attributes:
  • edge_index

  • cell_offsets

  • distances

  • neighbors

Warning

This function modifies the input data in-place.

Parameters
  • data (torch_geometric.Data) – The rewired graph data.

  • sn_indxes (torch.Tensor[torch.Long]) – Indices of the supernodes.

  • cutoff (float, optional) – Maximum edge length. Defaults to 6.0.

Returns

The updated graph.

Return type

torch_geometric.Data
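As a rough illustration of what removing over-long edges involves, here is a minimal sketch in plain Python (lists instead of tensors, so it runs without torch_geometric). It is not phast's implementation: the helper `drop_long_edges` is hypothetical, it treats an edge exactly at `cutoff` as allowed, it checks every edge instead of using `sn_indxes` to narrow the search to super-node edges, and it assumes `neighbors` holds the number of edges per graph.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]
Offset = Tuple[int, int, int]


def drop_long_edges(
    edge_index: Sequence[Edge],
    distances: Sequence[float],
    cell_offsets: Sequence[Offset],
    batch: Sequence[int],
    num_graphs: int,
    cutoff: float = 6.0,
) -> Tuple[List[Edge], List[float], List[Offset], List[int]]:
    """Hypothetical helper: drop every edge longer than `cutoff`."""
    # Assumption: an edge exactly at the cutoff is kept.
    keep = [d <= cutoff for d in distances]
    # Filter all per-edge attributes with the same mask so they stay aligned.
    new_edges = [e for e, k in zip(edge_index, keep) if k]
    new_distances = [d for d, k in zip(distances, keep) if k]
    new_offsets = [o for o, k in zip(cell_offsets, keep) if k]
    # Assumption: `neighbors` is the number of remaining edges per graph,
    # each edge being attributed to the graph of its source node.
    neighbors = [0] * num_graphs
    for src, _ in new_edges:
        neighbors[batch[src]] += 1
    return new_edges, new_distances, new_offsets, neighbors


# Two graphs of two atoms each; the second graph's edges exceed the cutoff.
edges, dists, offsets, neighbors = drop_long_edges(
    [(0, 1), (1, 0), (2, 3), (3, 2)],
    [1.5, 1.5, 7.0, 7.0],
    [(0, 0, 0)] * 4,
    batch=[0, 0, 1, 1],
    num_graphs=2,
)
# edges == [(0, 1), (1, 0)], neighbors == [2, 0]
```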

phast.graph_rewiring.one_supernode_per_atom_type(data, cutoff=6.0)

For each graph independently, replace all tag-0 atoms of a given element by a new super-node \(S_i,\ i \in \{1, 2, 3\}\). As in one_supernode_per_graph(), each \(S_i\) is positioned at the center of mass of the atoms it replaces in the \(x\) and \(y\) dimensions, but at the maximum height of those atoms in the \(z\) dimension.

Expected data attributes are the same as for remove_tag0_nodes().

Note

\(S_i\) conserves the atomic number of the tag-0 atoms it replaces.

Warning

This function modifies the input data in-place.

Parameters
  • data (torch_geometric.Data) – the data batch to re-wire

  • cutoff (float) – Maximum edge length. Defaults to 6.0.

Returns

the rewired data batch

Return type

torch_geometric.Data
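The placement rule for the \(S_i\) can be sketched in plain Python as follows. This is a hypothetical helper (`place_element_supernodes`), not phast's code, and it covers only the positions and atomic numbers of the super-nodes, not the edge rewiring. Because every atom in a group is of the same element, the unweighted mean of their \(x\) and \(y\) coordinates equals their center of mass in those dimensions.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, List, Sequence, Tuple

Vec = Tuple[float, float, float]


def place_element_supernodes(
    atomic_numbers: Sequence[int],
    tags: Sequence[int],
    pos: Sequence[Vec],
) -> List[Tuple[int, Vec]]:
    """Hypothetical helper: one (atomic number, position) pair per tag-0 element."""
    # Group the positions of tag-0 atoms by element.
    groups: Dict[int, List[Vec]] = defaultdict(list)
    for z, tag, p in zip(atomic_numbers, tags, pos):
        if tag == 0:
            groups[z].append(p)
    supernodes = []
    for z in sorted(groups):  # sorted only to make the output order deterministic
        members = groups[z]
        # Mean in x and y (same element, so this is the center of mass); max in z.
        x = fmean(p[0] for p in members)
        y = fmean(p[1] for p in members)
        top = float(max(p[2] for p in members))
        # S_i keeps the atomic number of the atoms it replaces.
        supernodes.append((z, (x, y, top)))
    return supernodes


# Two tag-0 Cu atoms, one tag-0 Al atom and one tag-2 (adsorbate) O atom.
supernodes = place_element_supernodes(
    atomic_numbers=[29, 29, 13, 8],
    tags=[0, 0, 0, 2],
    pos=[(0.0, 0.0, 0.0), (2.0, 0.0, 1.0), (0.0, 4.0, 2.0), (1.0, 1.0, 5.0)],
)
# supernodes == [(13, (0.0, 4.0, 2.0)), (29, (1.0, 0.0, 1.0))]
```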

phast.graph_rewiring.one_supernode_per_graph(data, cutoff=6.0, num_elements=83)

Replaces all tag-0 atoms with a single super-node \(S\) representing them, per graph. For each graph, \(S\) is the last node in the graph. \(S\) is positioned at the center of mass of all tag-0 atoms in the \(x\) and \(y\) directions, but at the maximum \(z\) coordinate of all tag-0 atoms. All atoms previously connected to a tag-0 atom are now connected to \(S\), unless that would create an edge longer than cutoff.

Expected data attributes are the same as for remove_tag0_nodes().

Note

\(S\) is created with a new atomic number \(Z_S = \mathrm{num\_elements} + 1\) (84 with the default of 83), so num_elements should be set to the number of elements expected in the whole dataset, not in the current graph.

Warning

This function modifies the input data in-place.

Parameters
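  • data (torch_geometric.Data) – single batch of graphs

  • cutoff (float) – Maximum edge length. Defaults to 6.0.

  • num_elements (int) – Number of elements expected in the dataset; \(S\) receives atomic number num_elements + 1. Defaults to 83.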

Return type

Union[torch_geometric.data.Batch, torch_geometric.data.Data]
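A minimal single-graph sketch of this rewiring, in plain Python, may help make the steps concrete. The function `single_supernode_sketch` is hypothetical and simplifies the real setting: it handles one graph rather than a batch, uses the unweighted mean of the tag-0 positions for \(x\) and \(y\) (a true center of mass would weight atoms by mass), ignores periodic cell offsets when measuring edge lengths, and assumes that self-loops and duplicate edges produced by the merge are dropped.

```python
import math
from statistics import fmean
from typing import Dict, List, Sequence, Tuple

Vec = Tuple[float, float, float]
Edge = Tuple[int, int]


def single_supernode_sketch(
    pos: Sequence[Vec],
    tags: Sequence[int],
    atomic_numbers: Sequence[int],
    edges: Sequence[Edge],
    cutoff: float = 6.0,
    num_elements: int = 83,
) -> Tuple[List[Vec], List[int], List[Edge]]:
    """Hypothetical single-graph sketch of the one-super-node rewiring."""
    tag0 = [i for i, t in enumerate(tags) if t == 0]
    if not tag0:  # nothing to aggregate
        return list(pos), list(atomic_numbers), list(edges)
    kept = [i for i, t in enumerate(tags) if t != 0]
    s = len(kept)  # S is appended as the last node of the graph
    new_index: Dict[int, int] = {old: new for new, old in enumerate(kept)}
    new_index.update({old: s for old in tag0})  # every tag-0 atom now maps to S
    # S: unweighted mean x and y of the tag-0 atoms, maximum z.
    sub = [pos[i] for i in tag0]
    s_pos = (
        fmean(p[0] for p in sub),
        fmean(p[1] for p in sub),
        float(max(p[2] for p in sub)),
    )
    new_pos = [pos[i] for i in kept] + [s_pos]
    new_z = [atomic_numbers[i] for i in kept] + [num_elements + 1]
    new_edges: List[Edge] = []
    seen = set()
    for a, b in edges:
        u, v = new_index[a], new_index[b]
        # Assumption: drop S-S self-loops and duplicate edges created by the merge.
        if u == v or (u, v) in seen:
            continue
        # Drop edges that moving an endpoint to S made longer than the cutoff
        # (simplification: Euclidean distance, periodic offsets ignored).
        if math.dist(new_pos[u], new_pos[v]) > cutoff:
            continue
        seen.add((u, v))
        new_edges.append((u, v))
    return new_pos, new_z, new_edges


new_pos, new_z, new_edges = single_supernode_sketch(
    pos=[(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (1.0, 0.0, 3.0), (7.0, 0.0, 3.0)],
    tags=[0, 0, 1, 1],
    atomic_numbers=[29, 29, 8, 1],
    edges=[(0, 2), (2, 0), (1, 2), (0, 1), (1, 3)],
)
# new_z == [8, 1, 84]; new_pos[-1] == (1.0, 0.0, 0.0)
# new_edges == [(2, 0), (0, 2)]: edge (1, 3) is dropped because S lies
# sqrt(45) ≈ 6.71 away from the H atom, beyond the cutoff of 6.0.
```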

phast.graph_rewiring.remove_tag0_nodes(data)

Delete sub-surface (data.tag == 0) nodes and rewire the graph accordingly.

Warning

This function modifies the input data in-place.

Expected data tensor attributes:
  • pos: node positions

  • atomic_numbers: atomic numbers

  • batch: mini-batch id for each atom

  • tags: atom tags

  • edge_index: edge indices as a \(2 \times E\) tensor

  • force: force vectors per atom (optional)

  • pos_relaxed: relaxed atom positions (optional)

  • fixed: mask for fixed atoms (optional)

  • natoms: number of atoms per graph

  • ptr: cumulative sum of natoms

  • cell_offsets: unit cell directional offset for each edge

  • distances: distance between each edge’s atoms

Parameters

data (torch_geometric.Data) – the data batch to re-wire

Return type

Union[torch_geometric.data.Batch, torch_geometric.data.Data]
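The re-indexing that node deletion requires can be sketched in plain Python as below. The helper `remove_tag0_sketch` is hypothetical, not phast's implementation; it tracks only batch, edge_index, natoms and ptr, and returns the edge mask that would be reused to filter per-edge attributes such as distances and cell_offsets. Per-node attributes (pos, atomic_numbers, tags, and the optional force, pos_relaxed and fixed) would be filtered with the node mask in the same way.

```python
from itertools import accumulate
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[int, int]


def remove_tag0_sketch(
    tags: Sequence[int],
    batch: Sequence[int],
    edges: Sequence[Edge],
    num_graphs: int,
) -> Tuple[List[int], List[Edge], List[bool], List[int], List[int]]:
    """Hypothetical sketch: drop tag-0 nodes and re-index the remaining graph."""
    node_keep = [t != 0 for t in tags]
    # Old node index -> new, contiguous node index.
    new_index: Dict[int, int] = {}
    for old, keep in enumerate(node_keep):
        if keep:
            new_index[old] = len(new_index)
    new_batch = [b for b, keep in zip(batch, node_keep) if keep]
    # An edge survives only if both endpoints survive; the same mask would be
    # applied to per-edge attributes such as distances and cell_offsets.
    edge_keep = [node_keep[a] and node_keep[b] for a, b in edges]
    new_edges = [
        (new_index[a], new_index[b])
        for (a, b), keep in zip(edges, edge_keep)
        if keep
    ]
    # natoms counts surviving atoms per graph; ptr is its cumulative sum from 0.
    natoms = [0] * num_graphs
    for g in new_batch:
        natoms[g] += 1
    ptr = [0, *accumulate(natoms)]
    return new_batch, new_edges, edge_keep, natoms, ptr


new_batch, new_edges, edge_keep, natoms, ptr = remove_tag0_sketch(
    tags=[0, 1, 2, 0, 1],
    batch=[0, 0, 0, 1, 1],
    edges=[(0, 1), (1, 2), (2, 1), (3, 4), (4, 3)],
    num_graphs=2,
)
# new_edges == [(0, 1), (1, 0)], natoms == [2, 1], ptr == [0, 2, 3]
```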