phast.graph_rewiring
In the context of the OC20 dataset, rewire each 3D molecular graph according to one of three strategies: remove all tag-0 atoms, aggregate all tag-0 atoms into a single super-node, or aggregate all tag-0 atoms of a given element into a single super-node (hence, up to 3 super-nodes will be created, since OC20 catalysts can have up to 3 elements).
from phast.graph_rewiring import remove_tag0_nodes
data = load_oc20_data_batch() # Yours to define
rewired_data = remove_tag0_nodes(data)
Warning
This module expects torch_geometric to be installed.
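Functions
adjust_cutoff_distances(data, sn_indxes, cutoff=6.0) | Because of rewiring, some edges could now be longer than the allowed cutoff distance.
one_supernode_per_atom_type(data, cutoff=6.0) | For each graph independently, replace all tag-0 atoms of a given element with a new super node \(S_i, \ i \in \{1..3\}\).
one_supernode_per_graph(data, cutoff=6.0, num_elements=83) | Replaces all tag-0 atoms with a single super-node \(S\) representing them, per graph.
remove_tag0_nodes(data) | Delete sub-surface (data.tags == 0) nodes and rewire the graph accordingly.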
- phast.graph_rewiring.adjust_cutoff_distances(data, sn_indxes, cutoff=6.0)
Because of rewiring, some edges could now be longer than the allowed cutoff distance. This function removes them.
Modified attributes:
* edge_index
* cell_offsets
* distances
* neighbors
Warning
This function modifies the input data in-place.
- Parameters
data (torch_geometric.Data) – The rewired graph data.
sn_indxes (torch.Tensor[torch.Long]) – Indices of the supernodes.
cutoff (float, optional) – Maximum edge length. Defaults to 6.0.
- Returns
The updated graph.
- Return type
torch_geometric.Data
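The cutoff filtering described above can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: filter_edges_by_cutoff is a hypothetical helper, it works on lists rather than tensors, it checks every edge rather than only those touching the super-nodes in sn_indxes, it ignores periodic cell offsets, and treating an edge of length exactly cutoff as valid is an assumption.

```python
import math


def filter_edges_by_cutoff(pos, edge_index, cutoff=6.0):
    """Keep only the edges whose Euclidean length is at most `cutoff`.

    pos: list of (x, y, z) node positions.
    edge_index: pair of lists (sources, targets), mirroring a 2 x E tensor.
    Hypothetical helper for illustration; periodic offsets are ignored.
    """
    sources, targets = edge_index
    kept_sources, kept_targets, kept_distances = [], [], []
    for i, j in zip(sources, targets):
        d = math.dist(pos[i], pos[j])  # straight-line distance between nodes
        if d <= cutoff:  # assumption: the boundary value is kept
            kept_sources.append(i)
            kept_targets.append(j)
            kept_distances.append(d)
    return (kept_sources, kept_targets), kept_distances


pos = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
edge_index = ([0, 0, 1], [1, 2, 2])
new_edge_index, new_distances = filter_edges_by_cutoff(pos, edge_index)
```

In the real function the same mask would also have to be applied to cell_offsets, and neighbors would have to be recounted per graph, since those attributes are listed as modified.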
- phast.graph_rewiring.one_supernode_per_atom_type(data, cutoff=6.0)
For each graph independently, replace all tag-0 atoms of a given element with a new super node \(S_i, \ i \in \{1..3\}\). As in one_supernode_per_graph(), each \(S_i\) is positioned at the center of mass of the atoms it replaces in the \(x\) and \(y\) dimensions, but at the maximum height of those atoms in the \(z\) dimension.
Expected data attributes are the same as for remove_tag0_nodes().
Note
\(S_i\) conserves the atomic number of the tag-0 atoms it replaces.
Warning
This function modifies the input data in-place.
- Parameters
data (torch_geometric.Data) – the data batch to re-wire
cutoff (float, optional) – Maximum edge length. Defaults to 6.0.
- Returns
the rewired data batch
- Return type
torch_geometric.Data
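As a rough illustration of the per-element grouping used by one_supernode_per_atom_type, the sketch below groups tag-0 atoms by atomic number and computes one super-node position per group. The helper name supernodes_per_element is hypothetical, the inputs are plain lists for a single graph, and the center of mass is taken as the unweighted mean (which coincides with the mass-weighted one here because all atoms in a group share the same element).

```python
from collections import defaultdict


def supernodes_per_element(positions, atomic_numbers, tags):
    """Return {atomic_number: (x, y, z)} for the tag-0 atoms of one graph.

    x and y are the mean over the group; z is the maximum over the group.
    Hypothetical helper for illustration only.
    """
    groups = defaultdict(list)
    for p, z_number, tag in zip(positions, atomic_numbers, tags):
        if tag == 0:  # only sub-surface atoms are aggregated
            groups[z_number].append(p)

    supernodes = {}
    for z_number, points in groups.items():
        xs, ys, zs = zip(*points)
        n = len(points)
        # The key doubles as the super-node's atomic number, which is conserved.
        supernodes[z_number] = (sum(xs) / n, sum(ys) / n, max(zs))
    return supernodes


positions = [(0.0, 0.0, 0.0), (2.0, 0.0, 1.0), (1.0, 1.0, 2.0), (5.0, 5.0, 5.0)]
atomic_numbers = [29, 29, 78, 8]
tags = [0, 0, 0, 1]
supernodes = supernodes_per_element(positions, atomic_numbers, tags)
```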
- phast.graph_rewiring.one_supernode_per_graph(data, cutoff=6.0, num_elements=83)
Replaces all tag-0 atoms with a single super-node \(S\) representing them, per graph. For each graph, \(S\) is the last node in the graph. \(S\) is positioned at the center of mass of all tag-0 atoms in the \(x\) and \(y\) directions, but at the maximum \(z\) coordinate of all tag-0 atoms. All atoms previously connected to a tag-0 atom are now connected to \(S\), unless that would create an edge longer than cutoff.
Expected data attributes are the same as for remove_tag0_nodes().
Note
\(S\) will be created with a new atomic number \(Z_{S} = num\_elements + 1\), so num_elements should be set to the number of elements expected to be present in the whole dataset, not in the current graph.
Warning
This function modifies the input data in-place.
- Parameters
data (data.Data) – single batch of graphs
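cutoff (float, optional) – Maximum edge length. Defaults to 6.0.
num_elements (int, optional) – Number of elements expected in the dataset; the super-node is given atomic number num_elements + 1. Defaults to 83.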
- Return type
Union[torch_geometric.data.Batch, torch_geometric.data.Data]
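The placement rule for \(S\) and its atomic number can be sketched in a few lines of plain Python. This is an illustration under assumptions, not the library's code: supernode_position and supernode_atomic_number are hypothetical names, and the center of mass is approximated by the unweighted mean of the tag-0 positions.

```python
def supernode_position(tag0_positions):
    """Mean of x and y, maximum of z, over the tag-0 atoms of one graph.

    Assumption: an unweighted mean stands in for the center of mass.
    """
    xs, ys, zs = zip(*tag0_positions)
    n = len(tag0_positions)
    return (sum(xs) / n, sum(ys) / n, max(zs))


def supernode_atomic_number(num_elements=83):
    """Atomic number given to S: one past the largest expected element index."""
    return num_elements + 1


tag0_positions = [(0.0, 0.0, 1.0), (2.0, 4.0, 3.0)]
s_pos = supernode_position(tag0_positions)
s_z = supernode_atomic_number()
```

Because \(Z_S\) depends only on num_elements, using the same value for every batch keeps the super-node's atomic number identical across the dataset.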
- phast.graph_rewiring.remove_tag0_nodes(data)
Delete sub-surface (data.tags == 0) nodes and rewire the graph accordingly.
Warning
This function modifies the input data in-place.
- Expected data tensor attributes:
* pos: node positions
* atomic_numbers: atomic numbers
* batch: mini-batch id for each atom
* tags: atom tags
* edge_index: edge indices as a \(2 \times E\) tensor
* force: force vectors per atom (optional)
* pos_relaxed: relaxed atom positions (optional)
* fixed: mask for fixed atoms (optional)
* natoms: number of atoms per graph
* ptr: cumulative sum of natoms
* cell_offsets: unit cell directional offset for each edge
* distances: distance between each edge’s atoms
- Parameters
data (torch_geometric.Data) – the data batch to re-wire
- Return type
Union[torch_geometric.data.Batch, torch_geometric.data.Data]
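To make the "delete and rewire" step concrete, here is a minimal plain-Python sketch of dropping tag-0 nodes and renumbering the surviving edges. The helper drop_tag0_nodes is hypothetical and handles a single graph with list inputs; the real function additionally has to update the per-node and per-edge attributes listed above (pos, batch, natoms, ptr, cell_offsets, distances, and so on).

```python
def drop_tag0_nodes(tags, edge_index):
    """Remove nodes with tag 0 and reindex the remaining edges.

    tags: list of integer tags, one per node.
    edge_index: pair of lists (sources, targets), mirroring a 2 x E tensor.
    Returns the kept original node indices and the reindexed edge lists.
    """
    kept_nodes = [i for i, tag in enumerate(tags) if tag != 0]
    # Map each surviving old index to its new, contiguous index.
    new_index = {old: new for new, old in enumerate(kept_nodes)}

    sources, targets = edge_index
    new_sources, new_targets = [], []
    for i, j in zip(sources, targets):
        # An edge survives only if both of its endpoints survive.
        if i in new_index and j in new_index:
            new_sources.append(new_index[i])
            new_targets.append(new_index[j])
    return kept_nodes, (new_sources, new_targets)


tags = [0, 1, 0, 2]
edge_index = ([0, 1, 3, 2], [1, 3, 1, 3])
kept_nodes, new_edge_index = drop_tag0_nodes(tags, edge_index)
```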