Understanding Scaled Dot-Product Attention in Neural Networks for Causal Discovery
Introduction
Causal discovery is a critical problem in fields such as healthcare and economics, requiring algorithms to uncover relationships between variables in datasets. This blog post dives into a neural network model designed for the Causal Discovery Challenge, a competition organized by ADIA Lab, which aims to advance the science of causal discovery in artificial intelligence.
The competition tasked participants with estimating the causal graph (DAG) for each provided dataset and accurately identifying the roles of nodes (variables) in the causal relationship between a specified treatment variable, X, and an outcome variable, Y. In each dataset, X is known to cause Y, represented as X → Y.
The data comprised a large-scale collection of 47,000 individual datasets, each containing 1,000 observations and between 3 and 10 variables. Training datasets were accompanied by their corresponding causal graphs, enabling participants to refine their models before addressing the test datasets. The evaluation focused on the accuracy of predicted DAGs, particularly in identifying roles such as "Confounder," "Mediator," or "Cause of Y" within the causal structure.
Model Architecture
The model takes a dataset of observations with up to 10 variables and predicts an adjacency matrix representing causal relationships. The architecture is inspired by Transformer models and employs two layers of scaled dot-product attention to capture asymmetry in causal relationships.
Key Components
- Input Layer: Transforms raw input data into latent representations.
- Two Layers of Scaled Dot-Product Attention: A core mechanism for capturing interactions between variables.
- Layer Normalization: Ensures stability during training.
- Final Layer: Outputs a masked matrix of probabilities, post-processed into a Directed Acyclic Graph (DAG).
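To make these components concrete, here is a minimal PyTorch sketch of how such a module might be wired up. The class name `CausalAttentionNet`, the hidden size `d`, and the exact layer shapes are assumptions for illustration, not the competition code; only `self.middle`, `layer_norm1`, and `layer_norm2` appear in the snippets later in this post.

```python
import torch
import torch.nn as nn

class CausalAttentionNet(nn.Module):
    """Illustrative skeleton: input layer, two attention layers,
    layer normalization, and a final layer producing edge logits."""

    def __init__(self, d: int = 64, dropout: float = 0.1):
        super().__init__()
        # Input layer: each scalar observation -> 2*d features (query + key)
        self.input = nn.Sequential(nn.Linear(1, 2 * d), nn.ReLU())
        # Projection producing the second attention layer's query/key pair
        self.middle = nn.Linear(d, 2 * d)
        self.layer_norm1 = nn.LayerNorm(d)
        self.layer_norm2 = nn.LayerNorm(d)
        self.dropout = nn.Dropout(dropout)
        # Final layer: pairwise features -> one logit per ordered pair (i, j)
        self.final = nn.Linear(d, 1)
```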
Handling Variable Size Inputs
The network handles datasets with varying numbers of variables using a masking mechanism, ensuring consistent representation without losing critical information.
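One way to implement this, sketched here under the assumption that datasets are zero-padded to a fixed width of 10 variables (the helper name `pad_and_mask` is hypothetical):

```python
import torch

def pad_and_mask(data: torch.Tensor, max_vars: int = 10):
    """Pad a (n_obs, n_vars) dataset to max_vars columns and return a
    boolean pair mask marking which adjacency-matrix entries are real."""
    n_obs, n_vars = data.shape
    padded = torch.zeros(n_obs, max_vars, dtype=data.dtype)
    padded[:, :n_vars] = data
    var_mask = torch.zeros(max_vars, dtype=torch.bool)
    var_mask[:n_vars] = True
    # An edge (i, j) can only exist between two real (non-padded) variables.
    pair_mask = var_mask[:, None] & var_mask[None, :]
    return padded, pair_mask
```

The pair mask can then be used to zero out the logits for padded variables before the loss is computed, so padding never contributes to the predicted DAG.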
Scaled Dot-Product Attention
The scaled dot-product attention is applied in two layers to compute directional relationships between variables. Here's how it works:
- Query and Key Tensors: Input features are transformed into `query` and `key` tensors through a linear transformation and ReLU activation in the input layer.
- First Attention Layer: The scaled dot-product attention is computed as:
x = torch.einsum('b s i d, b s j d -> b i j d', q, k) * (x.shape[1] ** -0.5)
This captures the initial interactions between variables.
- Intermediate Normalization and Attention: The output is normalized using layer normalization and passed through another attention layer:
q, k = self.middle(x).chunk(2, dim=-1)
x = torch.einsum('b s i d, b s j d -> b i j d', q, k) * (x.shape[1] ** -0.5)
This second attention layer refines the relationships learned in the first layer.
- Scaling: The results are scaled by the inverse square root of the number of observations for numerical stability.
- Dropout and Normalization: Dropout and layer normalization are applied after each attention layer to reduce overfitting and keep gradients consistent (a combined forward-pass sketch follows below).
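Putting the steps above together, a hedged sketch of the forward pass for the module outlined earlier might look like this. The tensor shapes, the input-layer details, and the final projection are assumptions; only the two `einsum` lines and `self.middle` come from the snippets above.

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass for the module sketched earlier.

    x: (batch, n_obs, n_vars) raw observations.
    Returns: (batch, n_vars, n_vars) edge logits.
    """
    # Input layer: per-observation features, split into query and key tensors
    q, k = self.input(x.unsqueeze(-1)).chunk(2, dim=-1)  # each (b, s, v, d)
    # First attention layer, scaled by the inverse square root of n_obs
    x = torch.einsum('b s i d, b s j d -> b i j d', q, k) * (x.shape[1] ** -0.5)
    x = self.dropout(self.layer_norm1(x))
    # Second attention layer refines the pairwise features (same scaling pattern)
    q, k = self.middle(x).chunk(2, dim=-1)
    x = torch.einsum('b s i d, b s j d -> b i j d', q, k) * (x.shape[1] ** -0.5)
    x = self.dropout(self.layer_norm2(x))
    # Final layer: one logit per ordered pair (i, j); masking and DAG
    # post-processing happen outside this function.
    return self.final(x).squeeze(-1)
```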
Layer Normalization
The layer normalization layers (`layer_norm1` and `layer_norm2`) are pivotal in achieving the high multi-balanced accuracy of 47.986%. They stabilize the training process by normalizing activations across the feature dimension, mitigating exploding or vanishing gradients.
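For reference, `nn.LayerNorm` over the last (feature) dimension of the pairwise tensor behaves as below; the feature size of 64 is an assumption for illustration.

```python
import torch
import torch.nn as nn

d = 64                               # assumed feature size
layer_norm = nn.LayerNorm(d)         # normalizes over the last (feature) dimension
x = torch.randn(8, 10, 10, d)        # (batch, vars, vars, features)
out = layer_norm(x)                  # ~zero mean, unit variance per feature vector
print(out.mean(dim=-1).abs().max())  # close to 0
```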
Comparative Results
The performance of the Transformer-based neural network was compared against several baselines. Here's how it stands out:
| Model | Accuracy |
|---|---|
| PC baseline | 37.640% |
| PC+GES | 39.898% |
| RandomForest baseline | 39.515% |
| NN Baseline | 37.499% |
| Transformer-based NN (2-layer attention) | 47.986% |
The Transformer-based neural network significantly outperformed traditional and baseline neural network approaches, thanks to its attention mechanism and normalization techniques.
Training and Results
The model uses Binary Cross-Entropy (BCE) loss with weighted classes to address class imbalance. Training uses the Adam optimizer with a learning rate scheduler. After 30 epochs, the model reaches the 47.986% accuracy reported above, outperforming the other approaches.
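A hedged sketch of such a training loop, assuming the module outlined earlier and a `train_loader` of (observations, adjacency) pairs; the class weight, learning rate, and scheduler choice are assumptions, not the competition settings.

```python
import torch
import torch.nn as nn

model = CausalAttentionNet()              # module sketched earlier
pos_weight = torch.tensor([3.0])          # assumed up-weighting of the rarer "edge" class
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(30):
    for data, adjacency in train_loader:  # assumed DataLoader of (observations, DAG) pairs
        logits = model(data)              # (batch, n_vars, n_vars) edge logits
        loss = criterion(logits, adjacency.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()                      # decay the learning rate each epoch
```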
Conclusion
This Transformer-based neural network architecture demonstrates the power of combining two layers of scaled dot-product attention with layer normalization for causal discovery. By embedding directional relationships directly into the model and stabilizing training with normalization, it achieves remarkable accuracy, setting a new benchmark in this domain.
The competition, organized by ADIA Lab, provided an excellent platform for participants to advance causal discovery methodologies. The large-scale datasets and challenging evaluation criteria underscored the need for innovative approaches, and this model exemplifies how deep learning can rise to meet such challenges.