phast.embedding
A Python module that endows graph neural networks with physical priors by incorporating them into the embeddings of atoms, which are computed from their atomic numbers.
This module implements a set of classes used to create atomic embeddings from the physical properties of periodic table elements.
The physical-property embeddings can be learned or kept fixed, depending on the use case. The embeddings can also include the group and period of each element.
In the context of the Open Catalyst datasets, tag embeddings can also be used.
This implementation relies on the Mendeleev package to access the physical properties of elements from the periodic table.
```python
import torch
from phast.embedding import PhysEmbedding

z = torch.randint(1, 85, (3, 12))  # batch of 3 graphs with 12 atoms each
phys_embedding = PhysEmbedding(
    z_emb_size=32,  # default
    period_emb_size=32,  # default
    group_emb_size=32,  # default
    properties_proj_size=32,  # default is 0 -> no learned projection
    n_elements=85,  # default
)
h = phys_embedding(z)  # h.shape = (3, 12, 128)

tags = torch.randint(0, 3, (3, 12))
phys_embedding = PhysEmbedding(
    tag_emb_size=32,  # default is 0, this is OC20-specific
    final_proj_size=64,  # default is 0: no projection, just the concatenation of embeddings
)
h = phys_embedding(z, tags)  # h.shape = (3, 12, 64)

# Assuming torch_geometric is installed:
# weights_only=False is needed on PyTorch >= 2.6, where torch.load defaults to weights_only=True
data = torch.load("examples/data/is2re_bs3.pt", weights_only=False)
h = phys_embedding(data.atomic_numbers.long(), data.tags)  # h.shape = (261, 64)
```
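Classes
PhysEmbedding – This module embeds inputs for use in a neural network, using both standard embeddings and physical properties.
PhysRef – This class implements an interface to access physical properties, period and group ids of elements from the periodic table.
PropertiesEmbedding – A class for retrieving physical properties from atomic numbers.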
- class phast.embedding.PhysEmbedding(z_emb_size=32, tag_emb_size=0, period_emb_size=32, group_emb_size=32, properties=PhysRef.default_properties, properties_grad=False, properties_proj_size=0, final_proj_size=0, n_elements=85)
Bases: torch.nn.Module
This module embeds inputs for use in a neural network, using both standard embeddings and physical properties. The inputs are atomic numbers and, optionally, Open Catalyst-style tags; period, group and the requested physical properties are looked up from the atomic numbers.
You can disable any embedding by setting its size to 0.
- Parameters
z_emb_size (int) – Size of the embedding for atomic numbers.
tag_emb_size (int) – Size of the embedding for tags.
period_emb_size (int) – Size of the embedding for periods.
group_emb_size (int) – Size of the embedding for groups.
properties (list) – List of the physical properties to include in the embedding. Each property is specified as a string and should name a property available in the Mendeleev tables (elements or fetch_ionization_energies).
properties_proj_size (int) – Projection size of the physical properties embedding (0 means no projection).
properties_grad (bool) – Whether the physical properties are trainable.
final_proj_size (int) – Projection size for the final embedding (0 means no projection).
n_elements (int) – Number of elements in the periodic table.
- Raises
ValueError – if self.properties_proj_size is greater than 0 and self.properties is empty
ValueError – if self.full_emb_size is 0, i.e. all sizes were set to 0.
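The size bookkeeping implied by these parameters, including the two ValueError conditions, can be sketched in plain Python. embedding_sizes and n_properties are hypothetical names, not part of phast, and the sketch assumes that when properties_proj_size is 0 the raw properties are concatenated unprojected, contributing one dimension each.

```python
def embedding_sizes(
    z_emb_size=32,
    tag_emb_size=0,
    period_emb_size=32,
    group_emb_size=32,
    n_properties=0,
    properties_proj_size=0,
    final_proj_size=0,
):
    """Hypothetical helper returning (full_emb_size, final_emb_size)."""
    # Mirrors the first documented ValueError
    if properties_proj_size > 0 and n_properties == 0:
        raise ValueError("properties_proj_size > 0 but no properties were given")
    # Assumption: unprojected properties contribute one dimension each
    phys_size = properties_proj_size if properties_proj_size > 0 else n_properties
    full_emb_size = (
        z_emb_size + tag_emb_size + period_emb_size + group_emb_size + phys_size
    )
    # Mirrors the second documented ValueError
    if full_emb_size == 0:
        raise ValueError("all embedding sizes are 0")
    # final_proj_size replaces the concatenated size only when it is > 0
    final_emb_size = final_proj_size if final_proj_size > 0 else full_emb_size
    return full_emb_size, final_emb_size


# First usage example (n_properties=5 is a placeholder): 32 + 32 + 32 + 32 = 128
print(embedding_sizes(n_properties=5, properties_proj_size=32))  # (128, 128)
```

With final_proj_size=64, as in the second usage example, the output size is 64 regardless of the concatenated width.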
- z_emb_size
Size of the embedding for atomic numbers.
- Type
int
- tag_emb_size
Size of the embedding for tags.
- Type
int
- period_emb_size
Size of the embedding for periods.
- Type
int
- group_emb_size
Size of the embedding for groups.
- Type
int
- properties
List of the physical properties to include in the embedding. Each property must be a string naming a column of the elements or fetch_ionization_energies Mendeleev tables.
- Type
list
- properties_grad
Whether to set the physical properties to be trainable or not.
- Type
bool
- n_elements
Number of elements in the periodic table to consider.
- Type
int
- full_emb_size
Total size of the concatenated embeddings.
- Type
int
- final_emb_size
Output size: final_proj_size if it is greater than 0, otherwise full_emb_size.
- Type
int
- embeddings
Dictionary containing the different embeddings.
- Type
nn.ModuleDict
- phys_lin
A linear layer to project the physical properties to the given size, if projection is requested.
- Type
nn.Linear
- final_proj
A linear layer to project the final embedding to the requested size.
- Type
nn.Linear
- forward(z, tag=None)
Embeds the input(s) using the available embeddings. The final embedding size is the sum of the individual embedding sizes, unless final_proj_size is greater than 0, in which case the concatenated embedding is projected to the requested size with a linear layer without bias.
- Parameters
z (torch.Tensor) – Tensor of (long) atomic numbers.
tag (Optional[torch.Tensor]) – Open Catalyst Project-style tags. Defaults to None.
- Returns
Embedded representation of the input(s).
- Return type
torch.Tensor
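The forward pass described above (concatenate the enabled embeddings along the feature axis, then optionally apply a bias-free linear projection) can be sketched as follows. ConcatEmbeddingSketch is a hypothetical, simplified class that only handles atomic-number and tag embeddings; the period, group and property lookups of PhysEmbedding are omitted, and n_tags=3 is an assumption taken from the torch.randint(0, 3, ...) call in the usage example.

```python
import torch
import torch.nn as nn


class ConcatEmbeddingSketch(nn.Module):
    """Hypothetical, simplified stand-in for PhysEmbedding (z and tag only)."""

    def __init__(self, n_elements=85, n_tags=3, z_emb_size=32, tag_emb_size=32,
                 final_proj_size=64):
        super().__init__()
        self.embeddings = nn.ModuleDict()
        if z_emb_size > 0:
            self.embeddings["z"] = nn.Embedding(n_elements, z_emb_size)
        if tag_emb_size > 0:
            self.embeddings["tag"] = nn.Embedding(n_tags, tag_emb_size)
        self.full_emb_size = z_emb_size + tag_emb_size
        if self.full_emb_size == 0:
            raise ValueError("at least one embedding size must be > 0")
        # Optional projection of the concatenation, without a bias term
        self.final_proj = (
            nn.Linear(self.full_emb_size, final_proj_size, bias=False)
            if final_proj_size > 0
            else None
        )

    def forward(self, z, tag=None):
        parts = []
        if "z" in self.embeddings:
            parts.append(self.embeddings["z"](z))
        if "tag" in self.embeddings:
            if tag is None:
                raise ValueError("tag is required when tag_emb_size > 0")
            parts.append(self.embeddings["tag"](tag))
        # Feature size after concatenation = sum of the enabled embedding sizes
        h = torch.cat(parts, dim=-1)
        if self.final_proj is not None:
            h = self.final_proj(h)
        return h


model = ConcatEmbeddingSketch()
z = torch.randint(1, 85, (3, 12))
tags = torch.randint(0, 3, (3, 12))
h = model(z, tags)  # shape (3, 12, 64)
```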
- class phast.embedding.PhysRef(properties=[], period=True, group=True, short=False, n_elements=85)
Bases: torch.nn.Module
This class implements an interface to access the physical properties, period and group ids of elements from the periodic table.
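- Parameters
properties (list) – List of the physical properties to include.
period (bool) – Whether to use period embeddings.
group (bool) – Whether to use group embeddings.
short (bool) – Whether to keep only the property columns that do not have NaN values.
n_elements (int) – Number of elements in the periodic table to consider.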
- properties_list
A list of the properties that are actually used for creating the embeddings.
- Type
list
- n_groups
The number of element groups.
- Type
int
- n_periods
The number of element periods.
- Type
int
- n_properties
The number of element properties used to create the embeddings.
- Type
int
- properties
Whether to create an embedding of physical properties.
- Type
bool
- properties_grad
Whether the physical properties embedding should be learned or kept fixed.
- Type
bool
- period
Whether to use period embeddings.
- Type
bool
- group
Whether to use group embeddings.
- Type
bool
- short
A boolean flag indicating whether to keep only the columns that do not have NaN values.
- Type
bool
- group_mapping
A tensor containing the mapping from the element atomic number to the corresponding group embedding.
- Type
torch.Tensor
- period_mapping
A tensor containing the mapping from the element atomic number to the corresponding period embedding.
- Type
torch.Tensor
- properties_mapping
A tensor containing the mapping from the element atomic number to the corresponding physical properties embedding.
- Type
torch.Tensor
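The mapping tensors are lookup tables indexed by atomic number. The sketch below illustrates the idea for periods, together with the NaN-column filtering described by short. build_period_mapping and drop_nan_columns are hypothetical helpers, not phast functions: the period is derived from the atomic number of the last element in each period (2, 10, 18, 36, 54, 86, 118), entry 0 is treated as padding, and the property table is a toy placeholder rather than real Mendeleev data.

```python
import torch

# Atomic number of the last element in periods 1 to 7
PERIOD_ENDS = torch.tensor([2, 10, 18, 36, 54, 86, 118])


def build_period_mapping(n_elements=85):
    """Hypothetical helper: entry z is the period of element z (entry 0 = padding)."""
    z = torch.arange(n_elements)
    # With right=False, bucketize returns i such that ends[i-1] < z <= ends[i]
    period = torch.bucketize(z, PERIOD_ENDS) + 1
    period[0] = 0  # there is no element with atomic number 0
    return period


def drop_nan_columns(table):
    """Keep only the property columns that contain no NaN (the `short` behaviour)."""
    keep = ~torch.isnan(table).any(dim=0)
    return table[:, keep]


period_mapping = build_period_mapping()
z = torch.tensor([[1, 8, 26], [6, 29, 79]])  # H, O, Fe / C, Cu, Au
periods = period_mapping[z]  # tensor([[1, 2, 4], [2, 4, 6]])

# Toy placeholder table: one row per element, one column per property
toy_table = torch.tensor([[1.0, float("nan"), 3.0], [4.0, 5.0, 6.0]])
short_table = drop_nan_columns(toy_table)  # the NaN column is removed
```

Indexing a mapping with a batch of atomic numbers returns a tensor of the same shape, which can then feed an nn.Embedding.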
- __init__()
Initializes the PhysRef class.
- Parameters
properties (list) –
period (bool) –
group (bool) –
short (bool) –
n_elements (int) –
- Return type
None
- class phast.embedding.PropertiesEmbedding(properties, grad=False)
Bases: torch.nn.Module
A class for retrieving physical properties from atomic numbers.
- Parameters
properties (torch.Tensor) – A tensor containing the properties to be embedded.
grad (bool) – Whether to enable gradient computation or not.
- properties
A parameter or buffer storing the properties.
- Type
nn.Parameter or nn.Buffer
- forward(z)
Returns the embedded properties at the specified indices.
- Parameters
z (torch.Tensor) – Tensor of (long) atomic numbers used as indices.
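The parameter-or-buffer pattern described for PropertiesEmbedding can be sketched as a plain row lookup. PropertiesLookup is a hypothetical name, and the phast class may differ in details such as how row 0 is handled.

```python
import torch
import torch.nn as nn


class PropertiesLookup(nn.Module):
    """Hypothetical sketch of a fixed-or-learned per-element property table."""

    def __init__(self, properties: torch.Tensor, grad: bool = False):
        super().__init__()
        if grad:
            # Trainable: returned by .parameters(), so an optimizer updates it
            self.properties = nn.Parameter(properties.clone())
        else:
            # Fixed: saved in state_dict and moved by .to(device), no gradient
            self.register_buffer("properties", properties.clone())

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Row lookup: output shape is z.shape + (n_properties,)
        return self.properties[z]
```

A buffer is never returned by parameters(), so an optimizer leaves it unchanged; wrapping the same tensor in nn.Parameter is what makes it trainable.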