Skip to content

Commit

Permalink
Add a kernel_regularizer kwarg to all Classification and Regression…
Browse files Browse the repository at this point in the history
… tasks,

which gets forwarded to the output (logits) Dense layer, aligning it with
the options for weight regularization in MtAlbis and other models.

PiperOrigin-RevId: 633570099
  • Loading branch information
arnoegw authored and tensorflower-gardener committed May 14, 2024
1 parent fa34af9 commit 126dc79
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 92 deletions.
70 changes: 56 additions & 14 deletions tensorflow_gnn/runner/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def __init__(
*,
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None,
):
"""Sets `Task` parameters.
Args:
Expand All @@ -120,6 +122,9 @@ def __init__(
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
if (label_fn is None) == (label_feature_name is None):
raise ValueError(
Expand All @@ -131,6 +136,7 @@ def __init__(
self._name = name
self._label_fn = label_fn
self._label_feature_name = label_feature_name
self._kernel_regularizer = kernel_regularizer

@abc.abstractmethod
def gather_activations(self, inputs: GraphTensor) -> Field:
Expand All @@ -148,7 +154,7 @@ def predict(self, inputs: tfgnn.GraphTensor) -> interfaces.Predictions:
tfgnn.check_scalar_graph_tensor(inputs, name="Classification")
activations = self.gather_activations(inputs)
logits = tf.keras.layers.Dense(
self._units,
self._units, kernel_regularizer=self._kernel_regularizer,
name=self._name)(activations)
return logits

Expand Down Expand Up @@ -374,7 +380,8 @@ def __init__(self,
reduce_type: str = "mean",
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None):
"""Graph binary (or multi-label) classification.
This task performs binary classification (or multiple independent ones:
Expand All @@ -394,6 +401,9 @@ def __init__(self,
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
super().__init__(
node_set_name,
Expand All @@ -402,7 +412,9 @@ def __init__(self,
reduce_type=reduce_type,
name=name,
label_fn=label_fn,
label_feature_name=label_feature_name)
label_feature_name=label_feature_name,
kernel_regularizer=kernel_regularizer,
)


class GraphMulticlassClassification(_GraphClassification,
Expand All @@ -419,7 +431,8 @@ def __init__(self,
reduce_type: str = "mean",
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None):
"""Graph multiclass classification from pooled node states.
Args:
Expand All @@ -439,6 +452,9 @@ def __init__(self,
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
super().__init__(
node_set_name,
Expand All @@ -449,7 +465,9 @@ def __init__(self,
reduce_type=reduce_type,
name=name,
label_fn=label_fn,
label_feature_name=label_feature_name)
label_feature_name=label_feature_name,
kernel_regularizer=kernel_regularizer,
)


class RootNodeBinaryClassification(_RootNodeClassification,
Expand All @@ -463,7 +481,8 @@ def __init__(self,
state_name: str = tfgnn.HIDDEN_STATE,
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None):
"""Root node binary (or multi-label) classification.
This task performs binary classification (or multiple independent ones:
Expand All @@ -486,14 +505,19 @@ def __init__(self,
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
super().__init__(
node_set_name,
units=units,
state_name=state_name,
name=name,
label_fn=label_fn,
label_feature_name=label_feature_name)
label_feature_name=label_feature_name,
kernel_regularizer=kernel_regularizer,
)


class RootNodeMulticlassClassification(_RootNodeClassification,
Expand All @@ -509,7 +533,8 @@ def __init__(self,
state_name: str = tfgnn.HIDDEN_STATE,
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None):
"""Root node multiclass classification.
This task can be used on graph datasets without a readout structure.
Expand All @@ -532,6 +557,9 @@ def __init__(self,
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
super().__init__(
node_set_name,
Expand All @@ -541,7 +569,9 @@ def __init__(self,
state_name=state_name,
name=name,
label_fn=label_fn,
label_feature_name=label_feature_name)
label_feature_name=label_feature_name,
kernel_regularizer=kernel_regularizer,
)


class NodeBinaryClassification(_NodeClassification, _BinaryClassification):
Expand All @@ -556,7 +586,8 @@ def __init__(self,
validate: bool = True,
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None):
"""Node binary (or multi-label) classification.
This task performs binary classification (or multiple independent ones:
Expand All @@ -582,6 +613,9 @@ def __init__(self,
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
super().__init__(
key,
Expand All @@ -591,7 +625,9 @@ def __init__(self,
validate=validate,
name=name,
label_fn=label_fn,
label_feature_name=label_feature_name)
label_feature_name=label_feature_name,
kernel_regularizer=kernel_regularizer,
)


class NodeMulticlassClassification(_NodeClassification,
Expand All @@ -609,7 +645,8 @@ def __init__(self,
per_class_statistics: bool = False,
name: str = "classification_logits",
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None):
label_feature_name: Optional[str] = None,
kernel_regularizer: Any = None):
"""Node multiclass classification via structured readout.
Args:
Expand All @@ -635,6 +672,9 @@ def __init__(self,
label_feature_name: A label feature name for readout from the auxiliary
'_readout' node set. Readout does not mutate the input `GraphTensor`.
Mutually exclusive with `label_fn`.
kernel_regularizer: Can be set to a `kernel_regularizer` as understood
by `tf.keras.layers.Dense` etc. to perform weight regularization of the
classification logits layer.
"""
super().__init__(
key,
Expand All @@ -646,4 +686,6 @@ def __init__(self,
per_class_statistics=per_class_statistics,
name=name,
label_fn=label_fn,
label_feature_name=label_feature_name)
label_feature_name=label_feature_name,
kernel_regularizer=kernel_regularizer,
)
47 changes: 33 additions & 14 deletions tensorflow_gnn/runner/tasks/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def fn(inputs):
return fn


def l2(rate):
return tf.keras.regularizers.L2(rate)


def add_readout_from_first_node(gt: GraphTensor) -> GraphTensor:
return tfgnn.add_readout_from_first_node(
gt,
Expand Down Expand Up @@ -205,63 +209,76 @@ def test_preprocess(
testcase_name="GraphBinaryClassification",
task=classification.GraphBinaryClassification(
"nodes",
label_fn=label_fn(2)),
label_fn=label_fn(2),
kernel_regularizer=l2(0.125)),
gt=TEST_GRAPH_TENSOR,
expected_loss=tf.keras.losses.BinaryCrossentropy,
expected_shape=tf.TensorShape((None, 1))),
expected_shape=tf.TensorShape((None, 1)),
expected_l2_regularization=0.125),
dict(
testcase_name="GraphMulticlassClassification",
task=classification.GraphMulticlassClassification(
"nodes",
num_classes=4,
label_feature_name="labels"),
label_feature_name="labels",
kernel_regularizer=l2(0.25)),
gt=context_readout_into_feature(4, TEST_GRAPH_TENSOR),
expected_loss=tf.keras.losses.SparseCategoricalCrossentropy,
expected_shape=tf.TensorShape((None, 4))),
expected_shape=tf.TensorShape((None, 4)),
expected_l2_regularization=0.25),
dict(
testcase_name="RootNodeBinaryClassification",
task=classification.RootNodeBinaryClassification(
"nodes",
label_fn=label_fn(2)),
label_fn=label_fn(2),
kernel_regularizer=l2(0.5)),
gt=TEST_GRAPH_TENSOR,
expected_loss=tf.keras.losses.BinaryCrossentropy,
expected_shape=tf.TensorShape((None, 1))),
expected_shape=tf.TensorShape((None, 1)),
expected_l2_regularization=0.5),
dict(
testcase_name="RootNodeMulticlassClassification",
task=classification.RootNodeMulticlassClassification(
"nodes",
num_classes=3,
label_feature_name="labels"),
label_feature_name="labels",
kernel_regularizer=l2(0.75)),
gt=context_readout_into_feature(3, TEST_GRAPH_TENSOR),
expected_loss=tf.keras.losses.SparseCategoricalCrossentropy,
expected_shape=tf.TensorShape((None, 3))),
expected_shape=tf.TensorShape((None, 3)),
expected_l2_regularization=0.75),
dict(
testcase_name="NodeBinaryClassification",
task=classification.NodeBinaryClassification(
READOUT_KEY,
label_fn=label_fn(2)),
label_fn=label_fn(2),
kernel_regularizer=l2(1.0)),
gt=add_readout_from_first_node(TEST_GRAPH_TENSOR),
expected_loss=tf.keras.losses.BinaryCrossentropy,
expected_shape=tf.TensorShape((None, 1))),
expected_shape=tf.TensorShape((None, 1)),
expected_l2_regularization=1.0),
dict(
testcase_name="NodeMulticlassClassification",
task=classification.NodeMulticlassClassification(
READOUT_KEY,
num_classes=3,
label_feature_name="labels"),
label_feature_name="labels",
kernel_regularizer=l2(0.375)),
gt=add_readout_from_first_node(context_readout_into_feature(
3,
TEST_GRAPH_TENSOR)),
expected_loss=tf.keras.losses.SparseCategoricalCrossentropy,
expected_shape=tf.TensorShape((None, 3))),
expected_shape=tf.TensorShape((None, 3)),
expected_l2_regularization=0.375),
])
def test_predict(
self,
task: interfaces.Task,
gt: GraphTensor,
expected_loss: Type[tf.keras.losses.Loss],
expected_shape: tf.TensorShape):
# Assert head readout, activation and shape.
expected_shape: tf.TensorShape,
expected_l2_regularization: float):
# Assert head readout, activation, shape and regularization.
inputs = tf.keras.layers.Input(type_spec=gt.spec)
model = tf.keras.Model(inputs, task.predict(inputs))
self.assertLen(model.layers, 3)
Expand All @@ -279,6 +296,8 @@ def test_predict(
_, _, dense = model.layers
self.assertEqual(dense.get_config()["activation"], "linear")
self.assertTrue(expected_shape.is_compatible_with(dense.output_shape))
self.assertEqual(dense.kernel_regularizer.get_config()["l2"],
expected_l2_regularization)

# Assert losses.
loss = task.losses()
Expand Down
Loading

0 comments on commit 126dc79

Please sign in to comment.