gnm.utils

Utility functions for working with generative network models.

This subpackage provides various utility functions that support the core generative network modeling functionality. It includes:

  • Statistical measures: Functions for statistical comparisons between networks
  • Graph properties: Various network metrics and measures for analyzing graph structure
  • Data validation: Functions to verify the validity of network data structures
  • Control networks: Functions for generating control networks with preserved properties
  • Datatype conversion: Functions to convert NumPy arrays and other datatypes to GNM-compatible tensors

These utilities handle both binary and weighted networks and are designed to operate on (batched) PyTorch tensors.

Controls

gnm.utils.get_control(matrices)

Generate control networks by randomly permuting connections while preserving edge count, symmetry and the empty diagonal.

This function creates randomised versions of the input networks while maintaining:

  • The same number of connections (for binary networks) or the same multiset of edge weights (for weighted networks)
  • Symmetry (undirected graph structure)
  • No self-connections (zeros on the diagonal)

Parameters:

  • matrices (Float[Tensor, 'num_networks num_nodes num_nodes'], required): Input adjacency or weight matrices with shape [num_networks, num_nodes, num_nodes].

Returns:

  • Float[Tensor, 'num_networks num_nodes num_nodes']: Permuted control networks with the same shape as the input matrices, preserving edge count (or edge weights), symmetry and the zero diagonal.

Examples:

>>> import torch
>>> from gnm.utils import get_control
>>> from gnm.defaults import get_binary_network
>>> # Get a real network
>>> real_network = get_binary_network()
>>> # Generate a control with preserved properties
>>> control_network = get_control(real_network)
>>> # Check that control has same number of connections
>>> real_network.sum() == control_network.sum()
tensor(True)
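Notes
  • For binary networks, the edges are reassigned to a uniformly random set of node pairs of the same size: the edge count is preserved, but node degrees generally are not
  • For weighted networks, connection weights are preserved but redistributed across node pairs
  • Only the upper triangle of each input matrix is read, so inputs are assumed to be symmetric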
Source code in src/gnm/utils/control.py
@jaxtyped(typechecker=typechecked)
def get_control(
    matrices: Float[torch.Tensor, "num_networks num_nodes num_nodes"]
) -> Float[torch.Tensor, "num_networks num_nodes num_nodes"]:
    """Generate control networks by randomly permuting connections while preserving network properties.

    This function creates randomized versions of the input networks while maintaining:
    - The same number of connections (for binary networks) or weight distribution (for weighted networks)
    - Symmetry (undirected graph structure)
    - No self-connections (zeros on diagonal)

    Args:
        matrices:
            Input adjacency or weight matrices with shape [num_networks, num_nodes, num_nodes]

    Returns:
        Permuted control networks with the same shape as the input matrices, preserving edge count (or edge weights), symmetry and the zero diagonal

    Examples:
        >>> import torch
        >>> from gnm.utils import get_control
        >>> from gnm.defaults import get_binary_network
        >>> # Get a real network
        >>> real_network = get_binary_network()
        >>> # Generate a control with preserved properties
        >>> control_network = get_control(real_network)
        >>> # Check that control has same number of connections
        >>> real_network.sum() == control_network.sum()
        tensor(True)

    Notes:
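        - For binary networks, the edges are reassigned to a uniformly random set of node pairs of the same size; the edge count is preserved, but node degrees generally are not
        - For weighted networks, connection weights are preserved but redistributed across node pairs
        - Only the upper triangle of each input matrix is read, so inputs are assumed to be symmetric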
    """
    num_networks, num_nodes, _ = matrices.shape
    control_networks = torch.zeros_like(matrices)

    # Process each network in the batch
    for i in range(num_networks):
        network = matrices[i]

        # Get upper triangular indices (excluding diagonal)
        indices = torch.triu_indices(num_nodes, num_nodes, offset=1)
        upper_values = network[indices[0], indices[1]]

        # Permute the upper triangular values
        perm_idx = torch.randperm(indices.shape[1])
        permuted_indices = (indices[0, perm_idx], indices[1, perm_idx])

        # Create a new network with permuted connections
        control = torch.zeros_like(network)
        control[permuted_indices[0], permuted_indices[1]] = upper_values

        # Ensure symmetry
        control = control + control.T
        control_networks[i] = control

    return control_networks
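
The per-network loop in `get_control` can also be written without a Python loop. The sketch below is not part of gnm: `batched_control` is a hypothetical name, and, like `get_control`, it assumes symmetric inputs because only the upper triangle is read. It draws an independent random permutation per network by argsorting uniform noise, places the permuted upper-triangular values back, and mirrors them to restore symmetry.

```python
import torch


def batched_control(matrices: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-free variant of get_control (sketch, not gnm API).

    matrices: [num_networks, num_nodes, num_nodes], assumed symmetric.
    """
    num_networks, num_nodes, _ = matrices.shape
    # Upper-triangular positions, excluding the diagonal
    rows, cols = torch.triu_indices(
        num_nodes, num_nodes, offset=1, device=matrices.device
    )
    upper_values = matrices[:, rows, cols]  # [num_networks, num_pairs]

    # One independent permutation per network: argsort of uniform noise
    perm = torch.argsort(
        torch.rand(num_networks, rows.numel(), device=matrices.device), dim=1
    )
    permuted_values = torch.gather(upper_values, 1, perm)

    # Write the shuffled values into the upper triangle of an empty matrix
    control = torch.zeros_like(matrices)
    control[:, rows, cols] = permuted_values
    # Mirror the upper triangle to restore symmetry; the diagonal stays zero
    return control + control.transpose(-1, -2)
```

Because every upper-triangular entry (zeros included) is shuffled, this keeps the edge count or weight multiset of each network but, like `get_control`, does not preserve the degree sequence.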

Statistics

gnm.utils.ks_statistic(samples_1, samples_2)

Compute Kolmogorov-Smirnov statistics between all pairs of distributions in two batches.

The Kolmogorov-Smirnov (KS) statistic measures the maximum absolute difference between two empirical cumulative distribution functions. This function computes KS statistics for all pairs of distributions between two batches of samples in a single vectorised pass, which is useful for comparing multiple generated networks with observed networks.

Parameters:

  • samples_1 (Float[Tensor, 'batch_1 num_samples_1'], required): First batch of samples with shape [batch_1, num_samples_1].
  • samples_2 (Float[Tensor, 'batch_2 num_samples_2'], required): Second batch of samples with shape [batch_2, num_samples_2].

Returns:

  • Float[Tensor, 'batch_1 batch_2']: KS statistics for all pairs with shape [batch_1, batch_2].

Examples:

>>> import torch
>>> from gnm.utils import ks_statistic
>>> # Create two batches of samples
>>> samples_1 = torch.randn(3, 100)  # 3 distributions, 100 samples each
>>> samples_2 = torch.randn(2, 150)  # 2 distributions, 150 samples each
>>> ks_stats = ks_statistic(samples_1, samples_2)
>>> ks_stats.shape
torch.Size([3, 2])
>>> # Each entry ks_stats[i,j] is the KS statistic between
>>> # the i-th distribution from batch 1 and j-th distribution from batch 2
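See Also
  • evaluation.KSCriterion: Uses KS statistics to compute the discrepancy between network measure distributions.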
Source code in src/gnm/utils/statistics.py
@jaxtyped(typechecker=typechecked)
def ks_statistic(
    samples_1: Float[torch.Tensor, "batch_1 num_samples_1"],
    samples_2: Float[torch.Tensor, "batch_2 num_samples_2"],
) -> Float[torch.Tensor, "batch_1 batch_2"]:
    r"""Compute Kolmogorov-Smirnov statistics between all pairs of distributions in two batches.

    The Kolmogorov-Smirnov (KS) statistic measures the maximum absolute difference
    between two empirical cumulative distribution functions. This function computes
    KS statistics for all pairs of distributions between two batches of samples in a
    single vectorised pass, which is useful for comparing multiple generated networks
    with observed networks.

    Args:
        samples_1:
            First batch of samples with shape [batch_1, num_samples_1]
        samples_2:
            Second batch of samples with shape [batch_2, num_samples_2]

    Returns:
        KS statistics for all pairs with shape [batch_1, batch_2]

    Examples:
        >>> import torch
        >>> from gnm.utils import ks_statistic
        >>> # Create two batches of samples
        >>> samples_1 = torch.randn(3, 100)  # 3 distributions, 100 samples each
        >>> samples_2 = torch.randn(2, 150)  # 2 distributions, 150 samples each
        >>> ks_stats = ks_statistic(samples_1, samples_2)
        >>> ks_stats.shape
        torch.Size([3, 2])
        >>> # Each entry ks_stats[i,j] is the KS statistic between
        >>> # the i-th distribution from batch 1 and j-th distribution from batch 2

    See Also:
        - [`evaluation.KSCriterion`][gnm.evaluation.KSCriterion]: Uses KS statistics to compute the discrepancy between network measure distributions
    """
    # Sort samples for CDF computation
    sorted_1, _ = torch.sort(samples_1, dim=1)  # [batch_1, n_samples_1]
    sorted_2, _ = torch.sort(samples_2, dim=1)  # [batch_2, n_samples_2]

    # Get all unique values that could be CDF evaluation points
    # Combine all samples and get unique sorted values
    all_values = torch.unique(
        torch.cat([sorted_1.reshape(-1), sorted_2.reshape(-1)])
    )  # [n_unique]

    # Compute CDFs for all distributions at these points
    # For each batch, compute the fraction of samples less than or equal to each value
    cdf_1 = (
        (sorted_1.unsqueeze(-1) <= all_values.unsqueeze(0).unsqueeze(0))
        .float()
        .mean(dim=1)
    )

    cdf_2 = (
        (sorted_2.unsqueeze(-1) <= all_values.unsqueeze(0).unsqueeze(0))
        .float()
        .mean(dim=1)
    )

    # Compute absolute differences between all pairs of CDFs
    # Use broadcasting to compute differences between all pairs in the batches
    differences = torch.abs(
        cdf_1.unsqueeze(1) - cdf_2.unsqueeze(0)
    )  # [batch_1, batch_2, n_unique]

    # Get maximum difference for each pair
    ks_statistics = torch.max(differences, dim=2).values  # [batch_1, batch_2]

    return ks_statistics
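
As a cross-check for a single pair of distributions, the same statistic can be computed with `torch.searchsorted`, which evaluates each empirical CDF at the pooled sample points without materialising the [batch, num_samples, n_unique] comparison tensor used above. `ks_pair` is a hypothetical helper, not part of gnm; like `ks_statistic`, it defines the CDF as the fraction of samples less than or equal to each point.

```python
import torch


def ks_pair(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Two-sample KS statistic for 1-D tensors x and y (sketch, not gnm API)."""
    x_sorted, _ = torch.sort(x)
    y_sorted, _ = torch.sort(y)
    # Evaluate both CDFs at every observed value; duplicates do not change the max
    points = torch.cat([x_sorted, y_sorted])
    # right=True counts the samples <= each point
    cdf_x = torch.searchsorted(x_sorted, points, right=True).to(torch.float64) / x.numel()
    cdf_y = torch.searchsorted(y_sorted, points, right=True).to(torch.float64) / y.numel()
    return (cdf_x - cdf_y).abs().max()


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([3.0, 4.0, 5.0, 6.0])
print(ks_pair(x, y))  # largest CDF gap is 0.5
```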

Checks

gnm.utils.binary_checks(matrices)

Check that matrices satisfy binary network constraints.

Validates that the provided adjacency matrices conform to the expected properties for binary networks:

  1. All values are either 0 or 1 (matrices are binary)
  2. Matrices are symmetric (undirected)
  3. No self-connections (zeros on the diagonal)

Parameters:

  • matrices (Float[Tensor, 'num_networks num_nodes num_nodes'], required): Adjacency matrices to check with shape [num_networks, num_nodes, num_nodes].

Raises:

  • AssertionError: If any of the conditions are not met, with a descriptive error message.

Examples:

>>> import torch
>>> from gnm.utils import binary_checks
>>> # Create a valid binary network
>>> valid_network = torch.zeros((1, 3, 3))
>>> valid_network[0, 0, 1] = 1
>>> valid_network[0, 1, 0] = 1
>>> binary_checks(valid_network)  # No error
>>>
>>> # Invalid binary network with non-binary values
>>> non_binary_network = torch.zeros((1, 3, 3))
>>> non_binary_network[0, 0, 1] = 0.5
>>> non_binary_network[0, 1, 0] = 0.5
>>> binary_checks(non_binary_network)  # Raises AssertionError: "Matrices must be binary"
>>>
>>> # Invalid binary network which is not symmetric
>>> non_symmetric_network = torch.zeros((1, 3, 3))
>>> non_symmetric_network[0, 0, 1] = 1
>>> binary_checks(non_symmetric_network)  # Raises AssertionError: "Matrices must be symmetric"
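See Also
  • utils.weighted_checks: For validating weighted networks.
  • defaults.get_binary_network: For loading pre-validated binary networks.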
Source code in src/gnm/utils/checks.py
@jaxtyped(typechecker=typechecked)
def binary_checks(matrices: Float[torch.Tensor, "num_networks num_nodes num_nodes"]):
    r"""Check that matrices satisfy binary network constraints.

    Validates that the provided adjacency matrices conform to the expected properties
    for binary networks:

    1. All values are either 0 or 1 (matrices are binary)
    2. Matrices are symmetric (undirected)
    3. No self-connections (zeros on the diagonal)

    Args:
        matrices:
            Adjacency matrices to check with shape [num_networks, num_nodes, num_nodes]

    Raises:
        AssertionError: If any of the conditions are not met, with a descriptive error message

    Examples:
        >>> import torch
        >>> from gnm.utils import binary_checks
        >>> # Create a valid binary network
        >>> valid_network = torch.zeros((1, 3, 3))
        >>> valid_network[0, 0, 1] = 1
        >>> valid_network[0, 1, 0] = 1
        >>> binary_checks(valid_network)  # No error
        >>>
        >>> # Invalid binary network with non-binary values
        >>> non_binary_network = torch.zeros((1, 3, 3))
        >>> non_binary_network[0, 0, 1] = 0.5
        >>> non_binary_network[0, 1, 0] = 0.5
        >>> binary_checks(non_binary_network)  # Raises AssertionError: "Matrices must be binary"
        >>>
        >>> # Invalid binary network which is not symmetric
        >>> non_symmetric_network = torch.zeros((1, 3, 3))
        >>> non_symmetric_network[0, 0, 1] = 1
        >>> binary_checks(non_symmetric_network)  # Raises AssertionError: "Matrices must be symmetric"

    See Also:
        - [`utils.weighted_checks`][gnm.utils.weighted_checks]: For validating weighted networks
        - [`defaults.get_binary_network`][gnm.defaults.get_binary_network]: For loading pre-validated binary networks
    """
    # Check that the matrices are binary:
    assert torch.all((matrices == 0) | (matrices == 1)), "Matrices must be binary"
    # Check that the matrices are symmetric:
    assert torch.allclose(
        matrices, matrices.transpose(-1, -2)
    ), "Matrices must be symmetric"
    # Check that the matrices are not self-connected:
    assert torch.all(
        matrices.diagonal(dim1=-2, dim2=-1) == 0
    ), "Matrices must not be self-connected"
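
Because these checks use `assert`, they are skipped entirely when Python runs with the `-O` flag. If validation must always run, or a boolean result is more convenient, a sketch like the following could be used instead. `is_valid_binary` is a hypothetical helper, not part of gnm; it tests the same three conditions as `binary_checks`.

```python
import torch


def is_valid_binary(matrices: torch.Tensor) -> bool:
    """Return True if every matrix is binary, symmetric and has a zero diagonal."""
    # Condition 1: every entry is exactly 0 or 1
    is_binary = bool(torch.all((matrices == 0) | (matrices == 1)))
    # Condition 2: each matrix equals its transpose (torch.allclose returns a bool)
    is_symmetric = torch.allclose(matrices, matrices.transpose(-1, -2))
    # Condition 3: no self-connections
    no_self_loops = bool(torch.all(matrices.diagonal(dim1=-2, dim2=-1) == 0))
    return is_binary and is_symmetric and no_self_loops
```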

gnm.utils.weighted_checks(matrices)

Check that matrices satisfy weighted network constraints.

Validates that the provided weight matrices conform to the expected properties for weighted networks:

  1. All values are non-negative
  2. Matrices are symmetric (undirected)
  3. No self-connections (zeros on the diagonal)

Parameters:

  • matrices (Float[Tensor, 'num_networks num_nodes num_nodes'], required): Weight matrices to check with shape [num_networks, num_nodes, num_nodes].

Raises:

  • AssertionError: If any of the conditions are not met, with a descriptive error message.

Examples:

>>> import torch
>>> from gnm.utils import weighted_checks
>>> # Create a valid weighted network
>>> valid_network = torch.zeros((1, 3, 3))
>>> valid_network[0, 0, 1] = 0.5
>>> valid_network[0, 1, 0] = 0.5
>>> weighted_checks(valid_network)  # No error
>>>
>>> # Invalid weighted network with negative values
>>> negative_network = torch.zeros((1, 3, 3))
>>> negative_network[0, 0, 1] = -0.5
>>> negative_network[0, 1, 0] = -0.5
>>> weighted_checks(negative_network)  # Raises AssertionError: "Matrices must be non-negative"
>>>
>>> # Invalid weighted network which is self-connected
>>> self_connected_network = torch.zeros((1, 3, 3))
>>> self_connected_network[0, 0, 0] = 1
>>> weighted_checks(self_connected_network)  # Raises AssertionError: "Matrices must not be self-connected"
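See Also
  • utils.binary_checks: For validating binary networks.
  • defaults.get_weighted_network: For loading pre-validated weighted networks.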
Source code in src/gnm/utils/checks.py
@jaxtyped(typechecker=typechecked)
def weighted_checks(matrices: Float[torch.Tensor, "num_networks num_nodes num_nodes"]):
    r"""Check that matrices satisfy weighted network constraints.

    Validates that the provided weight matrices conform to the expected properties
    for weighted networks:

    1. All values are non-negative
    2. Matrices are symmetric (undirected)
    3. No self-connections (zeros on the diagonal)

    Args:
        matrices:
            Weight matrices to check with shape [num_networks, num_nodes, num_nodes]

    Raises:
        AssertionError: If any of the conditions are not met, with a descriptive error message

    Examples:
        >>> import torch
        >>> from gnm.utils import weighted_checks
        >>> # Create a valid weighted network
        >>> valid_network = torch.zeros((1, 3, 3))
        >>> valid_network[0, 0, 1] = 0.5
        >>> valid_network[0, 1, 0] = 0.5
        >>> weighted_checks(valid_network)  # No error
        >>>
        >>> # Invalid weighted network with negative values
        >>> negative_network = torch.zeros((1, 3, 3))
        >>> negative_network[0, 0, 1] = -0.5
        >>> negative_network[0, 1, 0] = -0.5
        >>> weighted_checks(negative_network)  # Raises AssertionError: "Matrices must be non-negative"
        >>>
        >>> # Invalid weighted network which is self-connected
        >>> self_connected_network = torch.zeros((1, 3, 3))
        >>> self_connected_network[0, 0, 0] = 1
        >>> weighted_checks(self_connected_network)  # Raises AssertionError: "Matrices must not be self-connected"

    See Also:
        - [`utils.binary_checks`][gnm.utils.binary_checks]: For validating binary networks
        - [`defaults.get_weighted_network`][gnm.defaults.get_weighted_network]: For loading pre-validated weighted networks
    """
    # Check that the matrices are non-negative:
    assert torch.all(matrices >= 0), "Matrices must be non-negative"
    # Check that the matrices are symmetric:
    assert torch.allclose(
        matrices, matrices.transpose(-1, -2)
    ), "Matrices must be symmetric"
    # Check that the matrices are not self-connected:
    assert torch.all(
        matrices.diagonal(dim1=-2, dim2=-1) == 0
    ), "Matrices must not be self-connected"
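
Raw connectivity data can violate these constraints (small asymmetries, negative entries, non-zero diagonals). One possible preprocessing step, sketched below under the assumption that averaging the two directions and clipping negatives to zero is acceptable for your data, produces matrices that satisfy all three conditions checked by `weighted_checks`. `to_valid_weighted` is a hypothetical helper, not part of gnm.

```python
import torch


def to_valid_weighted(matrices: torch.Tensor) -> torch.Tensor:
    """Symmetrise, clip negatives and zero the diagonal (sketch, not gnm API)."""
    # Average each connection with its transpose to enforce symmetry
    w = 0.5 * (matrices + matrices.transpose(-1, -2))
    # Replace negative weights with zero (clamp returns a new tensor)
    w = w.clamp(min=0)
    # Remove self-connections; w is a new tensor, so the caller's input is untouched
    w.diagonal(dim1=-2, dim2=-1).zero_()
    return w
```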

Graph properties

gnm.utils.node_strengths(adjacency_matrix)

Compute the node strengths (or nodal degree) for each node in the network.

For binary networks, this is equivalent to the node degree (number of connections). For weighted networks, this represents the sum of all edge weights connected to each node.

Parameters:

  • adjacency_matrix (Float[Tensor, '*batch num_nodes num_nodes'], required): Adjacency matrix (binary or weighted) with shape [*batch, num_nodes, num_nodes].

Returns:

  • Float[Tensor, '*batch num_nodes']: Vector of node strengths for each node in the network with shape [*batch, num_nodes].

Examples:

>>> import torch
>>> from gnm.utils import node_strengths
>>> # Create a sample binary network
>>> adj_matrix = torch.zeros(1, 4, 4)
>>> adj_matrix[0, 0, 1] = 1
>>> adj_matrix[0, 1, 0] = 1
>>> adj_matrix[0, 1, 2] = 1
>>> adj_matrix[0, 2, 1] = 1
>>> strength = node_strengths(adj_matrix)
>>> strength
tensor([[1., 2., 1., 0.]])
See Also
  • evaluation.DegreeKS: Binary evaluation criterion which compares the distribution of node degrees between two binary networks.
  • evaluation.WeightedNodeStrengthKS: Weighted evaluation criterion which compares the distribution of node strengths between two weighted networks.
  • evaluation.DegreeCorrelation: Binary evaluation criterion which compares the node degrees of two binary networks using their correlation.
Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def node_strengths(
    adjacency_matrix: Float[torch.Tensor, "*batch num_nodes num_nodes"]
) -> Float[torch.Tensor, "*batch num_nodes"]:
    r"""Compute the node strengths (or nodal degree) for each node in the network.

    For binary networks, this is equivalent to the node degree (number of connections).
    For weighted networks, this represents the sum of all edge weights connected to each node.

    Args:
        adjacency_matrix:
            Adjacency matrix (binary or weighted) with shape [*batch, num_nodes, num_nodes]

    Returns:
        Vector of node strengths for each node in the network with shape [*batch, num_nodes]

    Examples:
        >>> import torch
        >>> from gnm.utils import node_strengths
        >>> # Create a sample binary network
        >>> adj_matrix = torch.zeros(1, 4, 4)
        >>> adj_matrix[0, 0, 1] = 1
        >>> adj_matrix[0, 1, 0] = 1
        >>> adj_matrix[0, 1, 2] = 1
        >>> adj_matrix[0, 2, 1] = 1
        >>> strength = node_strengths(adj_matrix)
        >>> strength
        tensor([[1., 2., 1., 0.]])

    See Also:
        - [`evaluation.DegreeKS`][gnm.evaluation.DegreeKS]: Binary evaluation criterion which compares the distribution of node degrees between two binary networks.
        - [`evaluation.WeightedNodeStrengthKS`][gnm.evaluation.WeightedNodeStrengthKS]: Weighted evaluation criterion which compares the distribution of node strengths between two weighted networks.
        - [`evaluation.DegreeCorrelation`][gnm.evaluation.DegreeCorrelation]: Binary evaluation criterion which compares the node degrees of two binary networks using their correlation.
    """
    return adjacency_matrix.sum(dim=-1)
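
For weighted networks, strength and degree differ. The sketch below uses plain tensor operations on a weighted triangle: `w.sum(dim=-1)` is the same computation `node_strengths` performs, while `(w > 0).sum(dim=-1)` counts non-zero connections to recover the binary degree.

```python
import torch

# Weighted triangle on nodes 0, 1, 2; node 3 is isolated
w = torch.zeros(1, 4, 4)
for i, j, weight in [(0, 1, 0.5), (1, 2, 0.8), (0, 2, 0.6)]:
    w[0, i, j] = weight
    w[0, j, i] = weight

strength = w.sum(dim=-1)      # same computation as node_strengths
degree = (w > 0).sum(dim=-1)  # number of non-zero connections per node
print(strength)  # [[1.1, 1.3, 1.4, 0.0]] up to float32 rounding
print(degree)    # tensor([[2, 2, 2, 0]])
```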

gnm.utils.binary_clustering_coefficients(adjacency_matrix)

Compute the clustering coefficients for each node in a binary network.

The clustering coefficient measures the degree to which nodes in a graph tend to cluster together. For a node i, it quantifies how close its neighbors are to being a complete subgraph (clique).

The clustering coefficient for a node \(i\) is computed as:

\[ c(i) = \frac{ 2t_i }{ k_i (k_i - 1) }, \]

where \(t_i\) is the number of (unordered) triangles around node \(i\), and \(k_i\) is the degree of node \(i\).

Parameters:

  • adjacency_matrix (Float[Tensor, '*batch num_nodes num_nodes'], required): Binary adjacency matrix with shape [*batch, num_nodes, num_nodes].

Returns:

  • Float[Tensor, '*batch num_nodes']: The clustering coefficients for each node with shape [*batch, num_nodes].

Examples:

>>> import torch
>>> from gnm.utils import binary_clustering_coefficients
>>> # Create a binary network with a triangle
>>> adj_matrix = torch.zeros(1, 4, 4)
>>> adj_matrix[0, 0, 1] = 1
>>> adj_matrix[0, 1, 0] = 1
>>> adj_matrix[0, 1, 2] = 1
>>> adj_matrix[0, 2, 1] = 1
>>> adj_matrix[0, 0, 2] = 1
>>> adj_matrix[0, 2, 0] = 1
>>> clustering = binary_clustering_coefficients(adj_matrix)
>>> clustering
tensor([[1., 1., 1., 0.]])
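See Also
  • utils.weighted_clustering_coefficients: For calculating clustering coefficients in weighted networks.
  • evaluation.ClusteringKS: Binary evaluation criterion which compares the distribution of clustering coefficients between two binary networks.
  • evaluation.ClusteringCorrelation: Binary evaluation criterion which compares the clustering coefficients of two binary networks using their correlation.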
Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def binary_clustering_coefficients(
    adjacency_matrix: Float[torch.Tensor, "*batch num_nodes num_nodes"]
) -> Float[torch.Tensor, "*batch num_nodes"]:
    r"""Compute the clustering coefficients for each node in a binary network.

    The clustering coefficient measures the degree to which nodes in a graph tend to cluster together.
    For a node i, it quantifies how close its neighbors are to being a complete subgraph (clique).

    The clustering coefficient for a node $i$ is computed as:

    $$
        c(i) = \frac{ 2t_i }{ k_i (k_i - 1) },
    $$

    where $t_i$ is the number of (unordered) triangles around node $i$, and $k_i$ is the degree of node $i$.

    Args:
        adjacency_matrix:
            Binary adjacency matrix with shape [*batch, num_nodes, num_nodes]

    Returns:
        The clustering coefficients for each node with shape [*batch, num_nodes]

    Examples:
        >>> import torch
        >>> from gnm.utils import binary_clustering_coefficients
        >>> # Create a binary network with a triangle
        >>> adj_matrix = torch.zeros(1, 4, 4)
        >>> adj_matrix[0, 0, 1] = 1
        >>> adj_matrix[0, 1, 0] = 1
        >>> adj_matrix[0, 1, 2] = 1
        >>> adj_matrix[0, 2, 1] = 1
        >>> adj_matrix[0, 0, 2] = 1
        >>> adj_matrix[0, 2, 0] = 1
        >>> clustering = binary_clustering_coefficients(adj_matrix)
        >>> clustering
        tensor([[1., 1., 1., 0.]])

    See Also:
        - [`utils.weighted_clustering_coefficients`][gnm.utils.weighted_clustering_coefficients]: For calculating clustering coefficient in weighted networks.
        - [`evaluation.ClusteringKS`][gnm.evaluation.ClusteringKS]: Binary evaluation criterion which compares the distribution of clustering coefficients between two binary networks.
        - [`evaluation.ClusteringCorrelation`][gnm.evaluation.ClusteringCorrelation]: Binary evaluation criterion which compares the clustering coefficients of two binary networks using their correlation.
    """
    binary_checks(adjacency_matrix)

    degrees = adjacency_matrix.sum(dim=-1)
    number_of_pairs = degrees * (degrees - 1)

    number_of_triangles = torch.diagonal(
        torch.matmul(
            torch.matmul(adjacency_matrix, adjacency_matrix), adjacency_matrix
        ),
        dim1=-2,
        dim2=-1,
    )

    clustering = torch.zeros_like(number_of_triangles)
    mask = number_of_pairs > 0

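    # diag(A^3) counts each triangle through a node twice (once per direction),
    # so it already equals 2 * t_i; no extra factor of 2 is needed (matches BCT output)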
    clustering[mask] = number_of_triangles[mask] / number_of_pairs[mask]
    return clustering
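
A small worked check of the formula: in the network below, nodes 0, 1 and 2 form a triangle and node 3 hangs off node 2. The diagonal of $A^3$ counts closed walks of length three, and each triangle through node $i$ is walked in two directions, so $(A^3)_{ii} = 2t_i$. That is why the source divides by $k_i(k_i-1)$ without a further factor of 2. The sketch repeats the computation with plain torch operations and does not call gnm.

```python
import torch

# Triangle 0-1-2 plus a pendant edge 2-3
a = torch.zeros(4, 4)
for i, j in [(0, 1), (1, 2), (0, 2), (2, 3)]:
    a[i, j] = a[j, i] = 1.0

closed_walks = torch.diagonal(a @ a @ a)  # (A^3)_ii = 2 * t_i
degrees = a.sum(dim=-1)
pairs = degrees * (degrees - 1)           # k_i (k_i - 1)

# Nodes with fewer than two neighbours get a clustering coefficient of 0
clustering = torch.zeros_like(closed_walks)
mask = pairs > 0
clustering[mask] = closed_walks[mask] / pairs[mask]
# Node 2 has degree 3 but only one triangle: 2 * 1 / (3 * 2) = 1/3
print(clustering)  # tensor([1.0000, 1.0000, 0.3333, 0.0000])
```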

gnm.utils.weighted_clustering_coefficients(weight_matrices)

Compute weighted clustering coefficients based on the definition of Onnela et al. (2005).

This implementation uses the geometric mean of triangle weights. For each node \(i\), the clustering coefficient is:

\[ c(i) = \frac{1}{k_i (k_i - 1)} \sum_{jk} (\hat{w}_{ij} \times \hat{w}_{jk} \times \hat{w}_{ki})^{1/3}, \]

where \(k_i\) is the degree of node \(i\) (its number of non-zero connections), and \(\hat{w}_{ij}\) is the weight of the edge between nodes \(i\) and \(j\) after normalising by dividing by the maximum weight in the network.

Parameters:

  • weight_matrices (Float[Tensor, '*batch num_nodes num_nodes'], required): Batch of weighted adjacency matrices with shape [*batch, num_nodes, num_nodes]. Weights should be non-negative.

Returns:

  • Float[Tensor, '*batch num_nodes']: Clustering coefficients for each node in each network with shape [*batch, num_nodes].

Examples:

>>> import torch
>>> from gnm.utils import weighted_clustering_coefficients
>>> # Create a weighted network with a triangle
>>> weight_matrix = torch.zeros(1, 4, 4)
>>> weight_matrix[0, 0, 1] = 0.5
>>> weight_matrix[0, 1, 0] = 0.5
>>> weight_matrix[0, 1, 2] = 0.8
>>> weight_matrix[0, 2, 1] = 0.8
>>> weight_matrix[0, 0, 2] = 0.6
>>> weight_matrix[0, 2, 0] = 0.6
>>> clustering = weighted_clustering_coefficients(weight_matrix)
>>> clustering.shape
torch.Size([1, 4])
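See Also
  • utils.binary_clustering_coefficients: For calculating clustering coefficients in binary networks.
  • evaluation.WeightedClusteringKS: Weighted evaluation criterion which compares the distribution of (weighted) clustering coefficients between two weighted networks.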
Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def weighted_clustering_coefficients(
    weight_matrices: Float[torch.Tensor, "*batch num_nodes num_nodes"]
) -> Float[torch.Tensor, "*batch num_nodes"]:
    r"""Compute weighted clustering coefficients based on the definition of Onnela et al. (2005).

    This implementation uses the geometric mean of triangle weights. For each node $i$,
    the clustering coefficient is:

    $$
    c(i) = \frac{1}{k_i (k_i - 1)} \sum_{jk} (\hat{w}_{ij} \times \hat{w}_{jk} \times \hat{w}_{ki})^{1/3},
    $$

    where $k_i$ is the degree of node $i$ (its number of non-zero connections), and $\hat{w}_{ij}$ is the weight of the edge between nodes $i$ and $j$,
    *after* normalising by dividing by the maximum weight in the network.

    Args:
        weight_matrices:
            Batch of weighted adjacency matrices with shape [*batch, num_nodes, num_nodes].
            Weights should be non-negative.

    Returns:
        Clustering coefficients for each node in each network with shape [*batch, num_nodes]

    Examples:
        >>> import torch
        >>> from gnm.utils import weighted_clustering_coefficients
        >>> # Create a weighted network with a triangle
        >>> weight_matrix = torch.zeros(1, 4, 4)
        >>> weight_matrix[0, 0, 1] = 0.5
        >>> weight_matrix[0, 1, 0] = 0.5
        >>> weight_matrix[0, 1, 2] = 0.8
        >>> weight_matrix[0, 2, 1] = 0.8
        >>> weight_matrix[0, 0, 2] = 0.6
        >>> weight_matrix[0, 2, 0] = 0.6
        >>> clustering = weighted_clustering_coefficients(weight_matrix)
        >>> clustering.shape
        torch.Size([1, 4])

    See Also:
        - [`utils.binary_clustering_coefficients`][gnm.utils.binary_clustering_coefficients]: For calculating clustering in binary networks.
        - [`evaluation.WeightedClusteringKS`][gnm.evaluation.WeightedClusteringKS]: Weighted evaluation criterion which compares the distribution of (weighted) clustering coefficients between two weighted networks.
    """
    weighted_checks(weight_matrices)

    # Take the cube root of each weight, so that the product of three edge
    # weights around a triangle becomes the cube root of their product
    normalised_w = torch.pow(weight_matrices, 1/3)

    # Divide by the largest (cube-rooted) weight in each network; this equals
    # (w_ij / max w)^(1/3), i.e. the cube root of the normalised weight
    max_weight = normalised_w.amax(dim=(-2, -1), keepdim=True)  # [*batch, 1, 1]
    normalised_w = normalised_w / max_weight

    # For each node u, sum the cube roots of triangle weight products:
    # sum_{v,w} (w_uv * w_vw * w_wu) ^ (1/3)
    triangles = torch.diagonal(
        torch.matmul(torch.matmul(normalised_w, normalised_w), normalised_w),
        dim1=-2,
        dim2=-1,
    )  # [*batch, num_nodes]

    # Get node degrees (number of non-zero connections)
    degree = torch.sum(weight_matrices > 0, dim=-1)  # [*batch, num_nodes]

    # Compute denominator k * (k-1) (k = degree)
    denom = degree * (degree - 1)  # [*batch, num_nodes]

    # Handle division by zero - set clustering to 0 where k <= 1
    clustering = torch.zeros_like(triangles)
    mask = denom > 0
    clustering[mask] = triangles[mask] / denom[mask]

    return clustering
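
The vectorised code above raises normalised weights to the power 1/3 before the matrix product, so the diagonal of the cubed matrix equals $\sum_{jk} (\hat{w}_{ij}\hat{w}_{jk}\hat{w}_{ki})^{1/3}$. A direct, loop-based transcription of the formula is handy for checking that equivalence on small inputs. `onnela_clustering_reference` is a hypothetical helper written for illustration, not part of gnm, and it handles a single [num_nodes, num_nodes] matrix with at least one positive weight.

```python
import torch


def onnela_clustering_reference(w: torch.Tensor) -> torch.Tensor:
    """Loop-based weighted clustering for one [n, n] weight matrix.

    Hypothetical reference helper (not gnm API); assumes w is symmetric,
    non-negative, has a zero diagonal and at least one positive weight.
    """
    n = w.shape[0]
    w_hat = w / w.max()           # normalise by the largest weight
    degree = (w > 0).sum(dim=-1)  # k_i: number of non-zero connections
    out = torch.zeros(n, dtype=w.dtype)
    for i in range(n):
        k = int(degree[i])
        if k < 2:
            continue  # clustering is left at 0 when k_i <= 1
        total = 0.0
        for j in range(n):
            for m in range(n):
                total += float(w_hat[i, j] * w_hat[j, m] * w_hat[m, i]) ** (1 / 3)
        out[i] = total / (k * (k - 1))
    return out


w = torch.zeros(4, 4, dtype=torch.float64)
for i, j, weight in [(0, 1, 0.5), (1, 2, 0.8), (0, 2, 0.6)]:
    w[i, j] = w[j, i] = weight
print(onnela_clustering_reference(w))
```

For this weighted triangle the normalised weights are 0.625, 1.0 and 0.75, so each triangle node gets $(0.625 \times 1.0 \times 0.75)^{1/3} = 0.46875^{1/3} \approx 0.777$ and the isolated node gets 0.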

gnm.utils.communicability(weight_matrix)

Compute the communicability matrix for a network.

Communicability measures the ease of information flow between nodes, taking into account all possible walks between them. It is based on the matrix exponential of the strength-normalised weight matrix.

To compute the communicability matrix, we go through the following steps:

  1. Compute the diagonal node strength matrix, \(S_{ii} = \sum_j W_{ij}\) (plus a small constant, \(10^{-6}\), to prevent division by zero).
  2. Compute the normalised weight matrix, \(S^{-1/2} W S^{-1/2}\).
  3. Compute the communicability matrix by taking the matrix exponential, \(\exp( S^{-1/2} W S^{-1/2} )\).
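
The three steps can be sketched with `torch.linalg.matrix_exp`. `communicability_sketch` is a hypothetical name, not part of gnm. It takes row sums as strengths (the source averages row and column sums, which is identical for symmetric input) and applies \(S^{-1/2}\) on both sides by broadcasting instead of building diagonal matrices.

```python
import torch


def communicability_sketch(w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """exp(S^{-1/2} W S^{-1/2}) for w of shape [*batch, n, n] (sketch, not gnm API)."""
    # Step 1: node strengths S_ii, plus eps to avoid division by zero
    strengths = w.sum(dim=-1) + eps   # [*batch, n]
    inv_sqrt = strengths.rsqrt()      # S_ii^{-1/2}
    # Step 2: scale row i by S_ii^{-1/2} and column j by S_jj^{-1/2}
    normalised = inv_sqrt.unsqueeze(-1) * w * inv_sqrt.unsqueeze(-2)
    # Step 3: batched matrix exponential
    return torch.linalg.matrix_exp(normalised)


w = torch.zeros(1, 3, 3, dtype=torch.float64)
w[0, 0, 1] = w[0, 1, 0] = 0.5
w[0, 1, 2] = w[0, 2, 1] = 0.8
print(communicability_sketch(w).shape)  # torch.Size([1, 3, 3])
```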

Parameters:

  • weight_matrix (Float[Tensor, '*batch num_nodes num_nodes'], required): Weighted adjacency matrix with shape [*batch, num_nodes, num_nodes].

Returns:

  • Float[Tensor, '*batch num_nodes num_nodes']: Communicability matrix with shape [*batch, num_nodes, num_nodes].

Examples:

>>> import torch
>>> from gnm.utils import communicability
>>> # Create a simple weighted network
>>> weight_matrix = torch.zeros(1, 3, 3)
>>> weight_matrix[0, 0, 1] = 0.5
>>> weight_matrix[0, 1, 0] = 0.5
>>> weight_matrix[0, 1, 2] = 0.8
>>> weight_matrix[0, 2, 1] = 0.8
>>> comm_matrix = communicability(weight_matrix)
>>> comm_matrix.shape
torch.Size([1, 3, 3])
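See Also
  • weight_criteria.Communicability: Weight optimisation criterion which minimises total communicability.
  • weight_criteria.NormalisedCommunicability: Weight optimisation criterion which minimises total communicability, divided by the maximum communicability.
  • weight_criteria.DistanceWeightedCommunicability: Weight optimisation criterion which minimises total communicability, weighted by the distance between nodes.
  • weight_criteria.NormalisedDistanceWeightedCommunicability: Weight optimisation criterion which minimises total communicability, weighted by the distance between nodes and divided by the maximum distance-weighted communicability.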
Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def communicability(
    weight_matrix: Float[torch.Tensor, "*batch num_nodes num_nodes"]
) -> Float[torch.Tensor, "*batch num_nodes num_nodes"]:
    r"""Compute the communicability matrix for a network.

    Communicability measures the ease of information flow between nodes, taking into
    account all possible walks between them. It is based on the matrix exponential of
    the strength-normalised weight matrix.

    To compute the communicability matrix, we go through the following steps:

    1. Compute the diagonal node strength matrix, $S_{ii} = \sum_j W_{ij}$ (plus a small constant, $10^{-6}$, to prevent division by zero).
    2. Compute the normalised weight matrix, $S^{-1/2} W S^{-1/2}$.
    3. Compute the communicability matrix by taking the matrix exponential, $\exp( S^{-1/2} W S^{-1/2} )$.

    Args:
        weight_matrix:
            Weighted adjacency matrix with shape [*batch, num_nodes, num_nodes]

    Returns:
        Communicability matrix with shape [*batch, num_nodes, num_nodes]

    Examples:
        >>> import torch
        >>> from gnm.utils import communicability
        >>> # Create a simple weighted network
        >>> weight_matrix = torch.zeros(1, 3, 3)
        >>> weight_matrix[0, 0, 1] = 0.5
        >>> weight_matrix[0, 1, 0] = 0.5
        >>> weight_matrix[0, 1, 2] = 0.8
        >>> weight_matrix[0, 2, 1] = 0.8
        >>> comm_matrix = communicability(weight_matrix)
        >>> comm_matrix.shape
        torch.Size([1, 3, 3])

    See Also:
        - [`weight_criteria.Communicability`][gnm.weight_criteria.Communicability]: weight optimisation criterion which minimises total communicability.
        - [`weight_criteria.NormalisedCommunicability`][gnm.weight_criteria.NormalisedCommunicability]: weight optimisation criterion which minimises total communicability, divided by the maximum communicability.
        - [`weight_criteria.DistanceWeightedCommunicability`][gnm.weight_criteria.DistanceWeightedCommunicability]: weight optimisation criterion which minimises total communicability, weighted by the distance between nodes.
        - [`weight_criteria.NormalisedDistanceWeightedCommunicability`][gnm.weight_criteria.NormalisedDistanceWeightedCommunicability]: weight optimisation criterion which minimises total communicability, weighted by the distance between nodes and divided by the maximum distance-weighted communicability.
    """
    # Compute the node strengths, with a small constant addition to prevent division by zero.
    node_strengths = (
        0.5 * (weight_matrix.sum(dim=-1) + weight_matrix.sum(dim=-2)) + 1e-6
    )

    # Build the batched diagonal matrix S^{-1/2}. torch.diag_embed places the last
    # dimension on the diagonal and inherits the input's dtype and device, so float64
    # or GPU inputs do not cause a dtype/device mismatch in the matmuls below.
    inv_sqrt_node_strengths = torch.diag_embed(1.0 / torch.sqrt(node_strengths))

    # Compute the normalised weight matrix
    normalised_weight_matrix = torch.matmul(
        torch.matmul(inv_sqrt_node_strengths, weight_matrix), inv_sqrt_node_strengths
    )

    # Compute the communicability matrix
    communicability_matrix = torch.matrix_exp(normalised_weight_matrix)

    return communicability_matrix
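The three steps listed in the docstring can be reproduced on a small, unbatched matrix without PyTorch. The helper below is a hypothetical sketch (communicability_sketch is not part of gnm) that works on nested Python lists and, as an assumption made for readability, replaces torch.matrix_exp with a truncated Taylor series I + A + A^2/2! + ..., which is only adequate for small, well-scaled matrices.

```python
import math


def communicability_sketch(weights, num_terms=30):
    """Return exp(S^-1/2 W S^-1/2) for one square weight matrix given as nested lists."""
    n = len(weights)
    # Step 1: node strengths, averaging row and column sums as the source does,
    # plus a small constant so isolated nodes do not cause division by zero.
    strengths = [
        0.5 * (sum(weights[i]) + sum(weights[j][i] for j in range(n))) + 1e-6
        for i in range(n)
    ]
    inv_sqrt = [1.0 / math.sqrt(s) for s in strengths]
    # Step 2: normalised matrix A_ij = W_ij / sqrt(s_i * s_j).
    a = [[inv_sqrt[i] * weights[i][j] * inv_sqrt[j] for j in range(n)] for i in range(n)]
    # Step 3: exp(A) as the truncated series I + A + A^2/2! + ...
    result = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    term = [row[:] for row in result]  # holds A^k / k!, starting from k = 0
    for k in range(1, num_terms):
        term = [
            [sum(term[i][m] * a[m][j] for m in range(n)) / k for j in range(n)]
            for i in range(n)
        ]
        result = [[result[i][j] + term[i][j] for j in range(n)] for i in range(n)]
    return result


# Same three-node network as in the docstring example, without the batch dimension.
w = [[0.0, 0.5, 0.0], [0.5, 0.0, 0.8], [0.0, 0.8, 0.0]]
comm = communicability_sketch(w)
```

For two nodes joined by a single edge the normalised matrix is (almost exactly) [[0, 1], [1, 0]], so the result should be close to [[cosh 1, sinh 1], [sinh 1, cosh 1]], a convenient hand check. For real, batched tensors the library function remains the one to use.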

gnm.utils.binary_betweenness_centrality(connectome, device=None)

Compute betweenness centrality for each node in binary networks.

Betweenness centrality quantifies the number of times a node acts as a bridge along the shortest path between two other nodes. It identifies nodes that control information flow in a network.

This function is implemented with batched PyTorch matrix operations and is intended for binary networks.

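Parameters:

Name Type Description Default
connectome Float[Tensor, '*batch num_nodes num_nodes']

Batch of binary adjacency matrices with shape [num_matrices, num_nodes, num_nodes]

required
device

Device on which intermediate tensors are created. If None, the device of connectome is used.

None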

Returns:

Type Description

Tensor of betweenness centralities for each node in each network with shape [num_matrices, num_nodes]

Examples:

>>> import torch
>>> from gnm.utils import binary_betweenness_centrality
>>> from gnm import defaults
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> binary_connectome = defaults.get_binary_network(device=device)
>>> betweenness = binary_betweenness_centrality(binary_connectome)
>>> betweenness.shape
torch.Size([1, 4])
Notes

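All computation stays in PyTorch on the input's device; there is no conversion to NumPy or NetworkX. The input must have exactly one leading batch dimension, because the implementation uses torch.bmm. Shortest paths are counted through successive powers of the adjacency matrix up to the network diameter, so large networks or batches may be computationally expensive.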

See Also
  • evaluation.BetweennessKS: Binary evaluation criterion which compares the distribution of betweenness centralities between two binary networks.
Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def binary_betweenness_centrality(
    connectome: Float[torch.Tensor, "*batch num_nodes num_nodes"], 
    device=None):

    r"""Compute betweenness centrality for each node in binary networks.

    Betweenness centrality quantifies the number of times a node acts as a bridge along
    the shortest path between two other nodes. It identifies nodes that control information
    flow in a network.

    This function is implemented with batched PyTorch matrix operations and is intended
    for binary networks.

    Args:
        connectome:
            Batch of binary adjacency matrices with shape [num_matrices, num_nodes, num_nodes]
        device:
            Device on which intermediate tensors are created. If None, the device of
            `connectome` is used.

    Returns:
        Tensor of betweenness centralities for each node in each network with shape [num_matrices, num_nodes]

    Examples:
        >>> import torch
        >>> from gnm.utils import binary_betweenness_centrality
        >>> from gnm import defaults
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> binary_connectome = defaults.get_binary_network(device=device)
        >>> betweenness = binary_betweenness_centrality(binary_connectome)
        >>> betweenness.shape
        torch.Size([1, 4])

    Notes:
        All computation stays in PyTorch on the input's device. The input must have
        exactly one leading batch dimension, because the implementation uses
        `torch.bmm`. Shortest paths are counted through successive powers of the
        adjacency matrix up to the network diameter, so large networks or batches
        may be computationally expensive.

    See Also:
        - [`evaluation.BetweennessKS`][gnm.evaluation.BetweennessKS]: Binary evaluation criterion which compares the distribution of betweenness centralities between two binary networks.
    """

    if device is None:
        device = connectome.device

    binary_checks(connectome)

    batch_size = connectome.shape[0]
    num_nodes = connectome.shape[-1]  

    # Identity matrix over batches
    single_identity = torch.eye(num_nodes, device=device)
    batch_identity = single_identity.repeat(batch_size, 1, 1)  # I

    num_shortest_paths = connectome.clone() 
    num_shortest_paths_length_d = torch.zeros_like(connectome)
    num_shortest_paths_lengths_any = torch.zeros_like(connectome)
    length_shortest_path = connectome.clone() 

    # Diagonal: one trivial path per node, with a placeholder length of 1 so that
    # self-pairs are never counted as newly reached in the loop below
    num_shortest_paths_lengths_any[batch_identity.bool()] = 1
    length_shortest_path[batch_identity.bool()] = 1

    for path_length in range(2, num_nodes + 1):
        num_shortest_paths = torch.bmm(num_shortest_paths, connectome)

        num_shortest_paths_length_d.copy_(num_shortest_paths)
        num_shortest_paths_length_d[length_shortest_path != 0] = 0

        # Update shortest path counts and lengths
        num_shortest_paths_lengths_any += num_shortest_paths_length_d
        length_shortest_path += path_length * (num_shortest_paths_length_d != 0)

        # Break if no new shortest paths are found
        if torch.all(num_shortest_paths_length_d == 0):
            break

    # Assign infinite length to unreachable node pairs, and zero length to self-pairs
    length_shortest_path = torch.where(length_shortest_path == 0, torch.inf, length_shortest_path)
    length_shortest_path[batch_identity.bool()] = 0

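    # Set a path count of 1 wherever none was recorded: unreachable pairs (avoids
    # division by zero) and direct neighbours, which have exactly one shortest path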
    num_shortest_paths_lengths_any = torch.where(num_shortest_paths_lengths_any == 0, 1, num_shortest_paths_lengths_any)

    # Initialize dependency matrix
    dependency = torch.zeros((batch_size, num_nodes, num_nodes), device=device)

    for path_length in range(path_length-1, 1, -1):
        temporary_path_dependency = torch.bmm(
            ((length_shortest_path == path_length).float() * (1 + dependency) / (num_shortest_paths_lengths_any + 1e-10)),
            connectome.transpose(-1, -2)
        ) * ((length_shortest_path == (path_length - 1)).float() * num_shortest_paths_lengths_any)

        dependency += temporary_path_dependency

    return dependency.sum(dim=1)  # Sum over node dependencies
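To sanity-check the matrix-based implementation on tiny graphs, a breadth-first (Brandes-style) computation can serve as an independent reference. The helper below is a hypothetical sketch, not part of gnm. Tracing the dependency recursion in the source by hand on a four-node path suggests it counts ordered (source, target) pairs, so each undirected pair contributes twice; the sketch follows that convention, and dividing its values by two gives the unordered-pair counts that NetworkX reports for undirected graphs with normalized=False. Treat the correspondence with gnm's output as an assumption to verify on your own data.

```python
from collections import deque


def betweenness_sketch(adj):
    """Betweenness of each node in one unweighted graph given as a 0/1 matrix.

    Counts ordered (source, target) pairs, so each undirected pair contributes twice.
    """
    n = len(adj)
    neighbours = [[j for j in range(n) if adj[i][j]] for i in range(n)]
    centrality = [0.0] * n
    for s in range(n):
        # Breadth-first search from s: distances, shortest-path counts, predecessors.
        dist = [-1] * n
        sigma = [0] * n
        preds = [[] for _ in range(n)]
        dist[s], sigma[s] = 0, 1
        order = []
        queue = deque([s])
        while queue:
            v = queue.popleft()
            order.append(v)
            for w in neighbours[v]:
                if dist[w] < 0:
                    dist[w] = dist[v] + 1
                    queue.append(w)
                if dist[w] == dist[v] + 1:
                    sigma[w] += sigma[v]
                    preds[w].append(v)
        # Accumulate dependencies from the farthest nodes back towards s.
        delta = [0.0] * n
        for w in reversed(order):
            for v in preds[w]:
                delta[v] += sigma[v] / sigma[w] * (1.0 + delta[w])
            if w != s:
                centrality[w] += delta[w]
    return centrality


path = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
print(betweenness_sketch(path))  # [0.0, 4.0, 4.0, 0.0]
```

On the path 0-1-2-3, node 1 lies on the shortest paths of the ordered pairs (0, 2), (2, 0), (0, 3) and (3, 0), hence the value 4.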

gnm.utils.binary_small_worldness(connectome, average_random_clustering=0.451, average_random_path_length=0.013)

Compute the small-worldness for each network in a batch.

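Small-worldness quantifies the degree to which a network exhibits small-world properties, which are characterized by high clustering and short path lengths. It is computed as (C / C_rand) / (L / L_rand), where C is the clustering of the network, L its characteristic path length, and C_rand and L_rand are given by average_random_clustering and average_random_path_length.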

Parameters:

Name Type Description Default
connectome Float[Tensor, '*batch num_nodes num_nodes']

Binary adjacency matrix with shape [*batch, num_nodes, num_nodes]

required
average_random_clustering float

Average clustering coefficient of random networks.

0.451
average_random_path_length float

Average shortest path length of random networks.

0.013

Returns:

Type Description

Small-worldness for each network with shape [*batch]

Examples:

>>> import torch
>>> from gnm.utils import binary_small_worldness
>>> from gnm import defaults
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> binary_connectome = defaults.get_binary_network(device=device)
>>> small_worldness = binary_small_worldness(binary_connectome)
Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def binary_small_worldness(
    connectome: Float[torch.Tensor, "*batch num_nodes num_nodes"], 
    average_random_clustering=0.451, 
    average_random_path_length=0.013):
    r"""Compute the small-worldness for each network in a batch.

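    Small-worldness quantifies the degree to which a network exhibits small-world properties,
    which are characterized by high clustering and short path lengths. It is computed as
    (C / C_rand) / (L / L_rand), where C_rand and L_rand are given by
    `average_random_clustering` and `average_random_path_length`.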

    Args:
        connectome:
            Binary adjacency matrix with shape [*batch, num_nodes, num_nodes]

        average_random_clustering (float): Average clustering coefficient of random networks.
        average_random_path_length (float): Average shortest path length of random networks.

    Returns:
        Small-worldness for each network with shape [*batch]

    Examples:
        >>> import torch
        >>> from gnm.utils import binary_small_worldness
        >>> from gnm import defaults
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> binary_connectome = defaults.get_binary_network(device=device)
        >>> small_worldness = binary_small_worldness(binary_connectome)
    """

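    # Only warn when at least one of the default reference values is actually in use
    if average_random_clustering == 0.451 or average_random_path_length == 0.013:
        warn("Using default values for average_random_clustering and average_random_path_length. "
             "Consider recalculating using simulate_random_graph_clustering function.")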

    binary_checks(connectome)

    # Average the per-node clustering coefficients so the result has shape [*batch]
    binary_clustering = binary_clustering_coefficients(connectome).mean(dim=-1)
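    # Use a distinct local name: assigning to the function's own name would make it
    # local to this scope and raise UnboundLocalError on the call
    path_length = binary_characteristic_path_length(connectome)

    # Small-worldness (sigma): normalised clustering divided by normalised path length
    small_worldness = (binary_clustering / average_random_clustering) / (path_length / average_random_path_length)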
    return small_worldness
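Written out for scalar inputs, the formula in the source is the ratio index usually denoted sigma; values clearly above 1 are conventionally read as small-world. The function name and the numbers below are hypothetical, chosen only to show the arithmetic. In practice the reference values should come from random networks matched to the real one (for example in node count and density); the defaults in the signature are not derived here.

```python
def small_world_sigma(clustering, path_length, random_clustering, random_path_length):
    """Ratio-based small-worldness: (C / C_rand) / (L / L_rand)."""
    if min(path_length, random_clustering, random_path_length) <= 0:
        raise ValueError("path lengths and reference clustering must be positive")
    gamma = clustering / random_clustering    # normalised clustering
    lam = path_length / random_path_length    # normalised characteristic path length
    return gamma / lam


# Hypothetical numbers: clustering 5x the random reference, path length about 1.1x.
sigma = small_world_sigma(clustering=0.5, path_length=2.0,
                          random_clustering=0.1, random_path_length=1.8)
print(round(sigma, 3))  # 4.5
```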

gnm.utils.weighted_small_worldness(connectome, average_random_clustering=0.451, average_random_path_length=0.013)

Calculate the weighted small-worldness of a connectome or a batch of connectomes.

Small-worldness measures how well a network balances local clustering and global integration. This function divides the mean weighted clustering coefficient by average_random_clustering and the average weighted shortest path length by average_random_path_length, then returns the ratio of the two (the ratio-based index often denoted sigma).

Parameters:

Name Type Description Default
connectome Float[Tensor, '*batch num_nodes num_nodes']

Batch of weighted adjacency matrices with shape (batch_size, num_nodes, num_nodes).

required
average_random_clustering float

Average clustering coefficient of a comparable random network.

0.451
average_random_path_length float

Average shortest path length of a comparable random network.

0.013

Returns:

Type Description
np.ndarray

A 1D NumPy array containing the small-worldness value of each connectome in the batch.

Raises:

Type Description
ValueError

If the input tensor does not have the expected shape or contains invalid data.

Notes
  • The input connectome is assumed to be weighted and undirected.
  • Self-loops are removed from the graph before shortest path lengths are calculated.
  • Shortest paths are computed with NetworkX, which treats edge weights as distances: a stronger connection makes a longer path.
  • NetworkX raises an error if a graph is disconnected.
  • The weighted clustering coefficients are computed by the helper function weighted_clustering_coefficients.

Examples:

>>> import torch
>>> from gnm.utils import weighted_small_worldness
>>> connectome = torch.rand(5, 10, 10)  # batch of 5 connectomes with 10 nodes each
>>> connectome = 0.5 * (connectome + connectome.transpose(-1, -2))  # make symmetric
>>> small_worldness = weighted_small_worldness(connectome)
>>> small_worldness.shape
(5,)

Source code in src/gnm/utils/graph_properties.py
@jaxtyped(typechecker=typechecked)
def weighted_small_worldness(connectome: Float[torch.Tensor, "*batch num_nodes num_nodes"], 
                             average_random_clustering=0.451, 
                             average_random_path_length=0.013):

    """
    Calculate the weighted small-worldness of a connectome or a batch of connectomes.

    Small-worldness measures how well a network balances local clustering and global
    integration. This function divides the mean weighted clustering coefficient by
    `average_random_clustering` and the average weighted shortest path length by
    `average_random_path_length`, then returns the ratio of the two (the ratio-based
    index often denoted sigma).

    Args:
        connectome (Float[torch.Tensor, "*batch num_nodes num_nodes"]):
            A batch of weighted adjacency matrices with shape
            (batch_size, num_nodes, num_nodes).
        average_random_clustering (float, optional):
            The average clustering coefficient of a comparable random network. Defaults to 0.451.
        average_random_path_length (float, optional):
            The average shortest path length of a comparable random network. Defaults to 0.013.

    Returns:
        np.ndarray:
            A 1D NumPy array containing the small-worldness value of each connectome
            in the batch.

    Raises:
        ValueError: If the input tensor does not have the expected shape or contains invalid data.

    Notes:
        - The function assumes that the input connectome is weighted and undirected.
        - Self-loops are removed from the graph before calculating shortest path lengths.
        - NetworkX treats edge weights as distances, so stronger connections give longer paths.
        - NetworkX raises an error if a graph is disconnected.
        - The weighted clustering coefficients are computed using the helper function
          `weighted_clustering_coefficients`.

    Example:
        >>> import torch
        >>> from gnm.utils import weighted_small_worldness
        >>> connectome = torch.rand(5, 10, 10)  # Batch of 5 connectomes with 10 nodes each
        >>> connectome = 0.5 * (connectome + connectome.transpose(-1, -2))  # make symmetric
        >>> small_worldness = weighted_small_worldness(connectome)
        >>> small_worldness.shape
        (5,)
    """

    # Real network measures
    connectome_np = connectome.detach().cpu().numpy()
    num_connectomes = connectome_np.shape[0]

    weighted_clustering = weighted_clustering_coefficients(connectome)
    weighted_clustering = weighted_clustering.detach().cpu().numpy()
    weighted_clustering_mean = np.mean(weighted_clustering, axis=1)

    small_worldness = []
    for i in range(num_connectomes):
        single_connectome = connectome_np[i, :, :]

        # Build an undirected weighted graph and drop self-loops
        G = nx.from_numpy_array(single_connectome)
        G.remove_edges_from(list(nx.selfloop_edges(G)))

        clustering_mean = weighted_clustering_mean[i]
        # NetworkX treats edge weights as distances and raises NetworkXError if G is disconnected
        shortest_path_length_mean = nx.average_shortest_path_length(G, weight="weight")

        # Small-worldness (sigma): normalised clustering divided by normalised path length
        sigma = (clustering_mean / average_random_clustering) / (shortest_path_length_mean / average_random_path_length)
        small_worldness.append(sigma)

    small_worldness = np.array(small_worldness)

    return small_worldness
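Because nx.average_shortest_path_length(G, weight="weight") interprets weights as distances, it can help to see that quantity computed directly. The sketch below (a hypothetical helper, not part of gnm) runs Dijkstra's algorithm from every node of one weighted matrix and averages over ordered pairs of distinct nodes, which is what NetworkX averages over; it raises ValueError for disconnected graphs, where NetworkX raises its own error. If weights encode connection strength, one common convention (an assumption here, not something the source does) is to convert them to lengths, for example by taking reciprocals, before computing paths.

```python
import heapq


def average_weighted_path_length(weights):
    """Mean shortest-path length over ordered pairs of distinct nodes.

    Edge weights are treated as distances; zero entries mean "no edge" and the
    diagonal is ignored, mirroring the removal of self-loops.
    """
    n = len(weights)
    if n < 2:
        raise ValueError("need at least two nodes")
    total, pairs = 0.0, 0
    for source in range(n):
        dist = [float("inf")] * n
        dist[source] = 0.0
        heap = [(0.0, source)]
        while heap:
            d, u = heapq.heappop(heap)
            if d > dist[u]:
                continue  # stale entry: a shorter route to u was already settled
            for v in range(n):
                w = weights[u][v]
                if v != u and w > 0 and d + w < dist[v]:
                    dist[v] = d + w
                    heapq.heappush(heap, (dist[v], v))
        for target in range(n):
            if target == source:
                continue
            if dist[target] == float("inf"):
                raise ValueError("graph is not connected")
            total += dist[target]
            pairs += 1
    return total / pairs


# The heavy direct edge (weight 5) between nodes 0 and 2 counts as "long", so the
# shortest 0-2 route goes through node 1 with length 1 + 1 = 2.
triangle = [[0.0, 1.0, 5.0], [1.0, 0.0, 1.0], [5.0, 1.0, 0.0]]
print(round(average_weighted_path_length(triangle), 4))  # 1.3333
```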

gnm.utils.generate_random_networks(num_nodes, density, seed, n=1, weighted=False)

Create a random graph with the given number of nodes and density.

Parameters:

Name Type Description Default
num_nodes int

Number of nodes in the graph.

required
density float

Density of the graph (between 0 and 1).

required
seed int

Random seed for reproducibility.

required
n int

Number of graphs to create.

1
weighted bool

If True, create a weighted graph.

False

Returns:

Name Type Description
Tensor Float[Tensor, 'n num_nodes num_nodes']

Adjacency matrices of shape (n, num_nodes, num_nodes)

Source code in src/gnm/utils/control.py
@jaxtyped(typechecker=typechecked)
def generate_random_networks(
    num_nodes: int, 
    density: float,
    seed: int, 
    n: int = 1, 
    weighted: bool = False
) -> Float[torch.Tensor, "n num_nodes num_nodes"]:
    """Create a random graph with the given number of nodes and density.

    Args:
        num_nodes (int): Number of nodes in the graph.
        density (float): Density of the graph (between 0 and 1).
        seed (int): Random seed for reproducibility.
        n (int): Number of graphs to create.
        weighted (bool): If True, create a weighted graph.

    Returns:
        Tensor: Adjacency matrices of shape (n, num_nodes, num_nodes)
    """

    torch.manual_seed(seed)

    # Keep the float dtype returned by torch.bernoulli so the Float return annotation holds
    graphs = torch.bernoulli(torch.full((n, num_nodes, num_nodes), density))

    # Make symmetric, no self-loops
    graphs = torch.triu(graphs, diagonal=1)
    graphs = graphs + graphs.transpose(1, 2)

    if weighted:
        weights = torch.rand(n, num_nodes, num_nodes)
        weights = torch.triu(weights, diagonal=1)
        weights = weights + weights.transpose(1, 2)
        graphs = graphs * weights

    return graphs
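The sampling scheme above draws each upper-triangular pair once with probability density and mirrors it, so the expected fraction of present edges equals density. The pure-Python sketch below (hypothetical names, not part of gnm) reproduces that scheme for one binary graph and checks the realised density. It uses a local random.Random(seed) instead of a global seed; in PyTorch the analogous choice would be passing a torch.Generator to torch.bernoulli and torch.rand through their generator argument, which avoids resetting global random state the way torch.manual_seed does.

```python
import random


def random_binary_network(num_nodes, density, seed):
    """Symmetric 0/1 adjacency matrix with no self-loops, as nested lists."""
    rng = random.Random(seed)  # local generator: no global random state is touched
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):  # sample the upper triangle only, then mirror
            if rng.random() < density:
                adj[i][j] = adj[j][i] = 1.0
    return adj


def edge_density(adj):
    """Fraction of the n * (n - 1) / 2 possible undirected edges that are present."""
    n = len(adj)
    edges = sum(adj[i][j] for i in range(n) for j in range(i + 1, n))
    return edges / (n * (n - 1) / 2)


graph = random_binary_network(num_nodes=100, density=0.3, seed=0)
print(edge_density(graph))  # roughly 0.3
```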

gnm.utils.characteristic_path_length(connectome)

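Compute the characteristic path length for each binary network: the mean shortest path length over all ordered pairs of distinct nodes, found with the Floyd-Warshall algorithm. Unreachable pairs keep a sentinel distance of 1e9, so a disconnected network returns a very large value.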

Source code in src/gnm/utils/graph_properties.py
def characteristic_path_length(
    connectome: Float[torch.Tensor, "*batch num_nodes num_nodes"]
) -> Float[torch.Tensor, "*batch"]:
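    r"""Compute the characteristic path length for each binary network.

    Shortest path lengths are found with the Floyd-Warshall algorithm and averaged
    over all ordered pairs of distinct nodes. Unreachable pairs keep a sentinel
    distance of 1e9, so a disconnected network returns a very large value.
    """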
    binary_checks(connectome)

    batch_shape = connectome.shape[:-2]
    n_nodes = connectome.shape[-1]

    connectome = connectome.clone()
    connectome[connectome == 0] = 1e9

    # Set diagonal to 0 (no self-distance)
    diag_idx = torch.arange(n_nodes, device=connectome.device)
    connectome[..., diag_idx, diag_idx] = 0

    # Floyd-Warshall algorithm: iteratively updates the shortest paths between all pairs of nodes using intermediate nodes
    for k in range(n_nodes):
        connectome = torch.minimum(
            connectome,
            connectome[..., :, k].unsqueeze(-1) + connectome[..., k, :].unsqueeze(-2)
        )

    # After shortest paths computed:
    # Mask diagonal (self-distances)
    mask = ~torch.eye(n_nodes, dtype=bool, device=connectome.device)

    shortest_paths = connectome[..., mask].reshape(*batch_shape, n_nodes, n_nodes - 1)

    # Mean over all node pairs
    path_length = shortest_paths.mean(dim=(-1, -2))  # mean over nodes and targets

    return path_length
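The same Floyd-Warshall relaxation can be written for a single binary matrix in plain Python, which makes the handling of unreachable pairs explicit. In this hypothetical sketch (not part of gnm), missing edges start at infinity rather than the 1e9 sentinel, and the mean is taken only over reachable ordered pairs. That exclusion is an assumed convention: on a connected graph it matches the source, but on a disconnected one it differs from the source, whose mean is dominated by the sentinel.

```python
INF = float("inf")


def characteristic_path_length_sketch(adj):
    """Mean hop distance over reachable ordered pairs of distinct nodes."""
    n = len(adj)
    # Initial distances: 0 on the diagonal, 1 for an edge, infinity otherwise.
    dist = [[0.0 if i == j else (1.0 if adj[i][j] else INF) for j in range(n)]
            for i in range(n)]
    # Floyd-Warshall: allow paths through intermediate node k, one k at a time.
    for k in range(n):
        for i in range(n):
            for j in range(n):
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
    reachable = [dist[i][j] for i in range(n) for j in range(n)
                 if i != j and dist[i][j] < INF]
    return sum(reachable) / len(reachable) if reachable else INF


path = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
print(round(characteristic_path_length_sketch(path), 4))  # 1.6667 (= 20 / 12)
```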