Source code for phast.embedding

"""
A Python module that endows graph neural networks with physical priors
as part of the embeddings of atoms, computed from their atomic number.

This package contains the implementation of a set of classes that are used to
create atomic embeddings from physical properties of periodic table elements.

The physical embeddings are learned or kept fixed depending on the specific use-case.
The embeddings can also include information regarding the group and period of the
elements.

In the context of the Open Catalyst datasets, tag embeddings can also be used.

This implementation relies on
`Mendeleev <https://mendeleev.readthedocs.io/en/stable/data.html>`_ package to access
the physical properties of elements from the periodic table.

.. image:: https://raw.githubusercontent.com/vict0rsch/phast/main/examples/data/embedding.png
    :alt: physical atom embeddings
    :width: 600px

.. code-block:: python

    import torch
    from phast.embedding import PhysEmbedding

    z = torch.randint(1, 85, (3, 12)) # batch of 3 graphs with 12 atoms each
    phys_embedding = PhysEmbedding(
        z_emb_size=32, # default
        period_emb_size=32, # default
        group_emb_size=32, # default
        properties_proj_size=32, # default is 0 -> no learned projection
        n_elements=85, # default
    )
    h = phys_embedding(z) # h.shape = (3, 12, 128)

    tags = torch.randint(0, 3, (3, 12))
    phys_embedding = PhysEmbedding(
        tag_emb_size=32, # default is 0, this is OC20-specific
        final_proj_size=64, # default is 0, no projection, just the concat. of embeds.
    )

    h = phys_embedding(z, tags) # h.shape = (3, 12, 64)

    # Assuming torch_geometric is installed:

    data = torch.load("examples/data/is2re_bs3.pt")
    h = phys_embedding(data.atomic_numbers.long(), data.tags) # h.shape = (261, 64)

"""

import os
from typing import Optional
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import Embedding, Linear

os.environ["SQLALCHEMY_SILENCE_UBER_WARNING"] = "1"
from mendeleev.fetch import fetch_ionization_energies, fetch_table


class PhysRef(nn.Module):
    """
    Interface to the physical properties, period ids and group ids of
    elements of the periodic table.

    The data is read from the Mendeleev ``elements`` and ionization-energy
    tables when the module is constructed, and exposed as lookup tensors
    indexed by atomic number (row ``0`` is a padding row).

    Attributes:
        default_properties (:obj:`list`): Default property names used for atom
            embeddings.
        properties_list (:obj:`list`): Names of the properties that are actually
            used for creating the embeddings (after NaN-based column filtering).
        n_groups (:obj:`int`): Number of group ids (max group id + 1), 0 if
            groups are disabled.
        n_periods (:obj:`int`): Number of period ids (max period id + 1), 0 if
            periods are disabled.
        n_properties (:obj:`int`): Number of property columns kept, 0 if
            properties are disabled.
        properties: The ``properties`` argument as given; a truthy value enables
            the properties table.
        period (:obj:`bool`): Whether the period mapping is built.
        group (:obj:`bool`): Whether the group mapping is built.
        short (:obj:`bool`): Whether to keep only the property columns that have
            no NaN value.
        group_mapping (:obj:`torch.Tensor`): Long tensor mapping an atomic number
            to its group id (only registered if ``group`` is True).
        period_mapping (:obj:`torch.Tensor`): Long tensor mapping an atomic
            number to its period id (only registered if ``period`` is True).
        properties_mapping (:obj:`torch.Tensor`): Float tensor mapping an atomic
            number to its normalized property vector (only registered if
            ``properties`` is truthy).
    """

    default_properties = [
        "atomic_radius",
        "atomic_volume",
        "density",
        "dipole_polarizability",
        "electron_affinity",
        "en_allen",
        "vdw_radius",
        "metallic_radius",
        "metallic_radius_c12",
        "covalent_radius_pyykko_double",
        "covalent_radius_pyykko_triple",
        "covalent_radius_pyykko",
        "IE1",
        "IE2",
    ]

    def __init__(
        self,
        properties: Optional[list] = None,
        period: bool = True,
        group: bool = True,
        short: bool = False,
        n_elements: int = 85,
    ) -> None:
        """
        Initializes the PhysRef class.

        Args:
            properties: List of properties to include in the atom embeddings.
                Each property must be a column name of the ``elements`` or
                ``fetch_ionization_energies``
                `Mendeleev tables <https://mendeleev.readthedocs.io/en/stable/notebooks/bulk_data_access.html>`_.
                A non-empty list/tuple selects exactly those columns; any other
                truthy value selects ``default_properties``. Defaults to no
                properties.
            period: Whether to create the period mapping, from atomic number to
                period number.
            group: Whether to create the group mapping, from atomic number to
                group number.
            short: Whether to keep only the property columns that do not have
                NaN values.
            n_elements: Highest atomic number to consider. Defaults to 85.

        Raises:
            ValueError: If a requested property is not a column of the tables.
        """
        super().__init__()

        # A fresh list per instance instead of a shared mutable default.
        if properties is None:
            properties = []

        # Honor an explicit selection; otherwise fall back to the defaults.
        if isinstance(properties, (list, tuple)) and properties:
            self.properties_list = list(properties)
        else:
            self.properties_list = list(self.default_properties)

        self.n_groups = 0
        self.n_periods = 0
        self.n_properties = 0

        self.properties = properties
        self.period = period
        self.group = group
        self.short = short

        # Load the table with all properties of all periodic table elements
        df = fetch_table("elements")
        df = df.set_index("atomic_number")

        # Add first and second ionization energies as extra columns
        ies = fetch_ionization_energies(degree=[1, 2])
        df = pd.concat([df, ies], axis=1)

        if self.group:
            # Elements without a group id in the table share the extra id 19
            df["group_id"] = df["group_id"].fillna(value=19.0)
            group_ids = df["group_id"].loc[:n_elements]
            self.n_groups = int(group_ids.max() + 1)
            # Index 0 is not an element: prepend a padding entry (id 1)
            group_mapping = torch.cat(
                [torch.ones(1), torch.tensor(group_ids.values)]
            ).long()
            self.register_buffer("group_mapping", group_mapping)

        if self.period:
            period_ids = df["period"].loc[:n_elements]
            self.n_periods = int(period_ids.max() + 1)
            # Index 0 is not an element: prepend a padding entry (id 1)
            period_mapping = torch.cat(
                [torch.ones(1), torch.tensor(period_ids.values)]
            ).long()
            self.register_buffer("period_mapping", period_mapping)

        if self.properties:
            unknown = [p for p in self.properties_list if p not in df.columns]
            if unknown:
                raise ValueError(
                    f"Unknown properties {unknown}: not columns of the Mendeleev"
                    " tables."
                )

            # Keep the requested columns for the first `n_elements` elements
            df = df.loc[:n_elements, self.properties_list]

            # Standardize each column (zero mean, unit std; NaNs are skipped)
            df = (df - df.mean()) / df.std()

            # Drop columns with missing values: any NaN if `short`, otherwise
            # only columns where at least half of the rows are NaN
            if self.short:
                keep = ~df.isnull().any()
            else:
                keep = df.isnull().sum() < df.shape[0] // 2
            self.properties_list = df.columns[keep].tolist()
            df = df[self.properties_list]

            if not self.short:
                # Impute remaining NaNs with the column mean. `fillna` returns
                # a new frame, so no assignment into a slice is needed.
                df = df.fillna(value=df.mean())

            self.n_properties = len(df.columns)
            # Row 0 is an all-zero padding row so rows align with atomic numbers
            properties_mapping = torch.cat(
                [
                    torch.zeros(1, self.n_properties),
                    torch.tensor(df.to_numpy(), dtype=torch.float32),
                ]
            )
            self.register_buffer("properties_mapping", properties_mapping)

    def __repr__(self) -> str:
        """Returns a string representation of the instance's configuration."""
        return (
            f"PhysRef(properties={self.properties}, period={self.period}, "
            f"group={self.group}, short={self.short})"
        )

    def period_and_group(self, z: torch.Tensor) -> dict:
        """
        Looks up the period and group ids of the given atomic numbers.

        Args:
            z: Integer tensor of atomic numbers, used to index the mappings.

        Returns:
            A dict with a ``"period"`` entry if periods are enabled and a
            ``"group"`` entry if groups are enabled.
        """
        values = {}
        if self.period:
            values["period"] = self.period_mapping[z]
        if self.group:
            values["group"] = self.group_mapping[z]
        return values
class PropertiesEmbedding(nn.Module):
    """
    A lookup table returning physical properties from atomic numbers.

    Args:
        properties (:obj:`torch.Tensor`): A tensor containing the properties to
            be embedded, indexed by atomic number along its first dimension.
        grad (:obj:`bool`): Whether the properties are learned or kept fixed.

    Attributes:
        properties (nn.Parameter or buffer): The stored properties; a parameter
            if ``grad`` is True, a buffer otherwise.

    Methods:
        forward(z): Returns the properties at the specified indices.
        reset_parameters(): Does nothing in this class.
    """

    def __init__(self, properties: torch.Tensor, grad: bool = False):
        """
        Initializes the PropertiesEmbedding object.

        Args:
            properties: A tensor containing the properties to use as
                embeddings.
            grad: Whether properties are fixed or learned (initialized from the
                given values then updated according to the gradient).
        """
        super().__init__()

        assert isinstance(properties, torch.Tensor)
        assert isinstance(grad, bool)

        if grad:
            # nn.Parameter(t) shares storage with `t`: without a copy, training
            # would overwrite the caller's reference table in place.
            self.register_parameter(
                "properties", nn.Parameter(properties.detach().clone())
            )
        else:
            self.register_buffer("properties", properties)

    def forward(self, z: torch.Tensor):
        """
        Returns the properties of each atom by indexing the stored table with
        the atomic numbers.

        Args:
            z: Tensor of atomic numbers as ``torch.Long``.

        Returns:
            The properties for each atom.
        """
        return self.properties[z]

    def reset_parameters(self):
        """No-op: the stored properties are left untouched."""
        pass
class PhysEmbedding(nn.Module):
    """
    Embeds atoms for use in a neural network, using both standard learned
    embeddings and physical properties.

    The inputs are atomic numbers and, optionally, tags. The output is the
    concatenation of the enabled embeddings (atomic number, physical
    properties, period, group, tag), optionally projected by a final linear
    layer. You can disable an embedding by setting its size to 0.

    Args:
        z_emb_size (:obj:`int`): Size of the embedding for atomic number.
        tag_emb_size (:obj:`int`): Size of the embedding for tags.
        period_emb_size (:obj:`int`): Size of the embedding for periods.
        group_emb_size (:obj:`int`): Size of the embedding for groups.
        properties (:obj:`list`): List of the physical properties to include in
            the embedding. Each property must be a string as per the
            ``elements`` or ``fetch_ionization_energies``
            `Mendeleev tables <https://mendeleev.readthedocs.io/en/stable/notebooks/bulk_data_access.html>`_.
            An empty list disables the properties embedding.
        properties_grad (:obj:`bool`): Whether to set the physical properties
            to be trainable or not.
        properties_proj_size (:obj:`int`): Projection size of the physical
            properties embedding. 0 means the raw properties are concatenated.
        final_proj_size (:obj:`int`): Projection size for the final embedding.
            0 means no final projection.
        n_elements (:obj:`int`): Number of elements in the periodic table to
            consider.

    Raises:
        ValueError: if `self.properties_proj_size` is greater than 0 and
            `self.properties` is empty
        ValueError: if `self.full_emb_size` is 0, i.e. all sizes were set to 0.
        AssertionError: if a size is negative or `properties_grad` is invalid.

    Attributes:
        z_emb_size (:obj:`int`): Size of the embedding for atomic number.
        tag_emb_size (:obj:`int`): Size of the embedding for tags.
        period_emb_size (:obj:`int`): Size of the embedding for periods.
        group_emb_size (:obj:`int`): Size of the embedding for groups.
        properties (:obj:`list`): List of the physical properties to include in
            the embedding.
        properties_grad (:obj:`bool`): Whether the physical properties are
            trainable or not.
        n_elements (:obj:`int`): Number of elements in the periodic table to
            consider.
        phys_ref (PhysRef): Reference physical information interface.
        full_emb_size (:obj:`int`): Total size of the concatenated embeddings.
        final_emb_size (:obj:`int`): Output size: either the final_proj_size or
            full_emb_size.
        embeddings (:obj:`nn.ModuleDict`): Dictionary containing the different
            embeddings. Its insertion order is the concatenation order.
        phys_lin (:obj:`nn.Linear`): A linear layer to project the physical
            properties to the given size, if projection is requested.
        final_proj (:obj:`nn.Linear`): A linear layer to project the final
            embedding to the requested size.
    """

    def __init__(
        self,
        z_emb_size: int = 32,
        tag_emb_size: int = 0,
        period_emb_size: int = 32,
        group_emb_size: int = 32,
        properties=PhysRef.default_properties,
        properties_grad: bool = False,
        properties_proj_size: int = 0,
        final_proj_size: int = 0,
        n_elements: int = 85,
    ):
        super().__init__()

        self.z_emb_size = z_emb_size
        self.tag_emb_size = tag_emb_size
        self.period_emb_size = period_emb_size
        self.group_emb_size = group_emb_size
        self.properties_proj_size = properties_proj_size
        self.final_proj_size = final_proj_size
        # Copy list-like inputs so the shared class-level default list is
        # never aliased (and thus never mutated through this instance).
        if isinstance(properties, (list, tuple)):
            properties = list(properties)
        self.properties = properties
        self.properties_grad = properties_grad
        self.n_elements = n_elements

        self.phys_lin = None
        self.final_proj = None

        # Check properties_grad is valid
        assert properties_grad in {
            True,
            False,
        }, f"Unknown properties_grad {properties_grad}. Allowed: True or False."

        if self.properties_proj_size > 0 and not self.properties:
            raise ValueError(
                "Cannot project physical properties if `self.properties` is empty."
            )

        # Check all sizes are non-negative (a negative projection size would
        # silently corrupt the total embedding size computed below)
        for emb_name, emb_size in {
            "z_emb_size": z_emb_size,
            "tag_emb_size": tag_emb_size,
            "period_emb_size": period_emb_size,
            "group_emb_size": group_emb_size,
            "properties_proj_size": properties_proj_size,
            "final_proj_size": final_proj_size,
        }.items():
            assert (
                emb_size >= 0
            ), f"Embedding size must be non-negative, got {emb_size} for {emb_name}"

        # Projected properties contribute `properties_proj_size`; unprojected
        # ones are added further down once their count is known.
        self.full_emb_size = int(
            self.z_emb_size
            + self.tag_emb_size
            + self.period_emb_size
            + self.group_emb_size
            + self.properties_proj_size * int(bool(self.properties))
        )

        # Physical properties
        self.phys_ref = PhysRef(
            properties=self.properties,
            period=self.period_emb_size > 0,
            group=self.group_emb_size > 0,
            n_elements=n_elements,
        )

        self.embeddings = nn.ModuleDict()

        # Main embedding
        # NOTE(review): this table accepts indices 0..n_elements-1 while the
        # PhysRef mappings are built up to atomic number n_elements -- confirm
        # whether z == n_elements is meant to be supported.
        if self.z_emb_size > 0:
            self.embeddings["z"] = Embedding(n_elements, self.z_emb_size)

        # Physical properties, with or without a learned projection
        if self.properties:
            properties_embedding = PropertiesEmbedding(
                self.phys_ref.properties_mapping, self.properties_grad
            )
            if self.properties_proj_size > 0:
                self.phys_lin = Linear(self.phys_ref.n_properties, properties_proj_size)
                self.embeddings["properties"] = nn.Sequential(
                    properties_embedding, self.phys_lin
                )
            else:
                self.embeddings["properties"] = properties_embedding
                self.full_emb_size += self.phys_ref.n_properties

        if self.full_emb_size == 0:
            raise ValueError("Total embedding size is 0!")

        # Period embedding
        if self.period_emb_size > 0:
            self.embeddings["period"] = Embedding(
                self.phys_ref.n_periods, self.period_emb_size
            )

        # Group embedding
        if self.group_emb_size > 0:
            self.embeddings["group"] = Embedding(
                self.phys_ref.n_groups, self.group_emb_size
            )

        # Tag embedding (3 possible tag values)
        if self.tag_emb_size > 0:
            self.embeddings["tag"] = Embedding(3, self.tag_emb_size)

        if self.final_proj_size > 0:
            self.final_proj = Linear(self.full_emb_size, self.final_proj_size)
            self.final_emb_size = self.final_proj_size
        else:
            self.final_emb_size = self.full_emb_size

        self.reset_parameters()

    def reset_parameters(self):
        """
        Resets the parameters of the linear layers and of the embeddings.

        The properties projection gets a Xavier-uniform weight; the final
        projection and the embedding tables use their own default
        initialization. The properties table itself is left untouched.
        """
        if self.phys_lin is not None:
            # Default reset first so the bias is re-drawn too, then override
            # the weight with the Xavier initialization.
            self.phys_lin.reset_parameters()
            nn.init.xavier_uniform_(self.phys_lin.weight)
        for emb in self.embeddings.values():
            # The Sequential wraps the properties table and `phys_lin`, which
            # was handled above.
            if isinstance(emb, (nn.Sequential, PhysRef)):
                continue
            emb.reset_parameters()
        if self.final_proj is not None:
            self.final_proj.reset_parameters()

    def forward(self, z: torch.Tensor, tag: Optional[torch.Tensor] = None):
        """
        Embeds the input(s) using the available embeddings.

        Final embedding size is the sum of the individual embedding sizes,
        except if `final_proj_size` is provided, in which case the final
        embedding is projected to the requested size with a linear layer.

        Args:
            z: Tensor of atomic numbers. It is cast to long before any lookup.
            tag: Open Catalyst Project-style tags. Defaults to None.

        Returns:
            :obj:`torch.Tensor`: Embedded representation of the input(s).
        """
        # Cast once so that every lookup (not just period/group) gets a valid
        # integer index tensor.
        z = z.long()
        pg = self.phys_ref.period_and_group(z)

        parts = []
        for name, emb in self.embeddings.items():
            if name in pg:
                parts.append(emb(pg[name]))
            elif name in {"z", "properties"}:
                parts.append(emb(z))
            elif name == "tag":
                assert tag is not None, "Tag embedding is used but no tag is provided."
                parts.append(emb(tag.long()))

        h = torch.cat(parts, dim=-1)

        if self.final_proj is not None:
            h = self.final_proj(h)

        return h