Skip to content

Commit

Permalink
Add clrs_text to __init__.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649677897
  • Loading branch information
RerRayne authored and copybara-github committed Jul 5, 2024
1 parent a5314f3 commit 9bf6807
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 43 deletions.
9 changes: 9 additions & 0 deletions clrs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,29 @@
"""The CLRS Algorithmic Reasoning Benchmark."""

from clrs import models

from clrs._src import algorithms
from clrs._src import clrs_text
from clrs._src import decoders
from clrs._src import processors

from clrs._src.dataset import chunkify
from clrs._src.dataset import CLRSDataset
from clrs._src.dataset import create_chunked_dataset
from clrs._src.dataset import create_dataset
from clrs._src.dataset import get_clrs_folder
from clrs._src.dataset import get_dataset_gcp_url

from clrs._src.evaluation import evaluate
from clrs._src.evaluation import evaluate_hints

from clrs._src.model import Model

from clrs._src.probing import DataPoint
from clrs._src.probing import predecessor_to_cyclic_predecessor_and_first

from clrs._src.processors import get_processor_factory

from clrs._src.samplers import build_sampler
from clrs._src.samplers import CLRS30
from clrs._src.samplers import Features
Expand All @@ -40,6 +48,7 @@
from clrs._src.samplers import process_random_pos
from clrs._src.samplers import Sampler
from clrs._src.samplers import Trajectory

from clrs._src.specs import ALGO_IDX_INPUT_NAME
from clrs._src.specs import CLRS_30_ALGS_SETTINGS
from clrs._src.specs import Location
Expand Down
17 changes: 17 additions & 0 deletions clrs/_src/clrs_text/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The CLRS Text Algorithmic Reasoning Benchmark."""

from clrs._src.clrs_text import clrs_utils
90 changes: 47 additions & 43 deletions clrs/_src/clrs_text/clrs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# ==============================================================================
"""Functions to create text versions of CLRS data."""
from typing import Any, Optional

import clrs
from clrs._src import samplers
from clrs._src import specs
import numpy as np


Expand Down Expand Up @@ -78,7 +78,7 @@

def format_clrs_example(
algo: str,
sample: clrs.Feedback,
sample: samplers.Feedback,
use_hints: bool = False,
) -> tuple[str, str]:
"""Formats CLRS example into prompt for the LLM.
Expand Down Expand Up @@ -112,7 +112,7 @@ def format_clrs_example(

def _get_output_names(
algo_name: str,
spec: clrs.Spec,
spec: specs.Spec,
use_hints: bool,
) -> list[str]:
"""Gets the output names for a CLRS algorithm."""
Expand All @@ -124,12 +124,12 @@ def _get_output_names(
return [
spec_name
for spec_name in spec
if spec[spec_name][0] == clrs.Stage.OUTPUT
if spec[spec_name][0] == specs.Stage.OUTPUT
]


def _get_output_str(
sample: clrs.Feedback, spec, algo_name: str, use_hints: bool
sample: samplers.Feedback, spec, algo_name: str, use_hints: bool
) -> list[str]:
"""Gets the output string for a CLRS algorithm."""
if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER and use_hints:
Expand Down Expand Up @@ -157,7 +157,7 @@ def _get_output_str(

def sample_to_str(
algo: str,
sample: clrs.Feedback,
sample: samplers.Feedback,
use_hints: bool = False,
) -> tuple[str, str, str, bool]:
"""Converts a CLRS sample into input and output strings.
Expand Down Expand Up @@ -206,7 +206,7 @@ def sample_to_str(
Returns:
A 4-tuple: the (input, output_names, output) strings, followed by a bool (per the return annotation).
"""
spec = clrs.SPECS[algo]
spec = specs.SPECS[algo]

# Create input prompt.
input_strs = _create_input_feature_strs(spec, sample.features.inputs)
Expand Down Expand Up @@ -252,16 +252,16 @@ def sample_to_str(


def _create_input_feature_strs(
spec: clrs.Spec,
inputs: clrs.Features,
spec: specs.Spec,
inputs: samplers.Features,
) -> list[str]:
"""Extracts input features and convert them into strings."""
input_strs = []
for spec_name in spec:

stage, _, _ = spec[spec_name] # (stage, location, type)

if stage != clrs.Stage.INPUT:
if stage != specs.Stage.INPUT:
continue

if _do_not_include_input_in_text(spec_name, spec):
Expand All @@ -279,16 +279,16 @@ def _create_input_feature_strs(


def _create_output_feature_strs(
spec: clrs.Spec,
inputs: clrs.Features,
outputs: clrs.Features,
spec: specs.Spec,
inputs: samplers.Features,
outputs: samplers.Features,
) -> list[str]:
"""Extracts output features and convert them into strings."""
output_strs = []
for spec_name in spec:
stage, _, _ = spec[spec_name]

if stage != clrs.Stage.OUTPUT:
if stage != specs.Stage.OUTPUT:
continue

x = _get_feature_by_name(outputs, spec_name).data
Expand Down Expand Up @@ -339,9 +339,9 @@ def _format_hint(hints: list[str], algo_name: str) -> str:

def _create_hint_feature_strs(
algo_name: str,
spec: clrs.Spec,
inputs: clrs.Features,
hints: clrs.Features,
spec: specs.Spec,
inputs: samplers.Features,
hints: samplers.Features,
output_names: list[str],
) -> tuple[str, str, bool]:
"""Extracts hint features and convert them into strings."""
Expand Down Expand Up @@ -405,10 +405,10 @@ def _create_hint_feature_strs(

def _feature_to_str(
name: str,
spec: clrs.Spec,
spec: specs.Spec,
x: np.ndarray,
with_name: bool,
inputs: Optional[clrs.Features] = None,
inputs: Optional[samplers.Features] = None,
edge_masks_as_edge_list: bool = False,
) -> str:
"""Converts a numerical CLRS feature into a string."""
Expand All @@ -421,22 +421,22 @@ def _feature_to_str(
x = x[0]
unused_stage, location, typ_ = spec[name]
match location:
case clrs.Location.NODE:
case specs.Location.NODE:
output = _convert_node_features_to_str(
x=x,
spec_name=name,
spec=spec,
spec_type=typ_,
inputs=inputs,
)
case clrs.Location.GRAPH:
case specs.Location.GRAPH:
output = _convert_graph_features_to_str(
x=x,
spec_name=name,
spec=spec,
spec_type=typ_,
)
case clrs.Location.EDGE:
case specs.Location.EDGE:
output = _convert_edge_features_to_str(
x=x,
spec_name=name,
Expand Down Expand Up @@ -469,13 +469,13 @@ def predecessors_to_order(x: np.ndarray) -> np.ndarray:
def _convert_node_features_to_str(
x: np.ndarray,
spec_name: str,
spec: clrs.Spec,
spec: specs.Spec,
spec_type: str,
inputs: Optional[clrs.Features] = None,
inputs: Optional[samplers.Features] = None,
) -> str:
"""Converts node features into string."""
match spec_type:
case clrs.Type.SHOULD_BE_PERMUTATION:
case specs.Type.SHOULD_BE_PERMUTATION:
# For the text version of CLRS, if the output is a permutation, we present
# the "key" input values in the order given by the permutation.
nonsorted_values = _get_feature_by_name(inputs, 'key').data[0]
Expand All @@ -488,15 +488,15 @@ def _convert_node_features_to_str(
SEQUENCE_SEPARATOR.join([f'{scalar:.3g}' for scalar in sorted_values])
)

case clrs.Type.MASK_ONE:
case specs.Type.MASK_ONE:
[index] = x.nonzero()[0]
return f'{index}'

case clrs.Type.SCALAR:
case specs.Type.SCALAR:
return _bracket(SEQUENCE_SEPARATOR.join([f'{a:.3g}' for a in x]))

case clrs.Type.MASK | clrs.Type.POINTER | clrs.Type.CATEGORICAL:
if spec_type == clrs.Type.CATEGORICAL:
case specs.Type.MASK | specs.Type.POINTER | specs.Type.CATEGORICAL:
if spec_type == specs.Type.CATEGORICAL:
categories = np.argmax(x, axis=-1)
int_output = categories
else:
Expand All @@ -510,20 +510,24 @@ def _convert_node_features_to_str(
def _convert_graph_features_to_str(
x: np.ndarray,
spec_name: str,
spec: clrs.Spec,
spec: specs.Spec,
spec_type: str,
) -> str:
"""Converts graph features into string."""
match spec_type:
case clrs.Type.SCALAR:
case specs.Type.SCALAR:
return f'{x:.3f}'

case clrs.Type.CATEGORICAL:
case specs.Type.CATEGORICAL:
categories = np.argmax(x, axis=-1)
return f'{categories}'

case _:
if spec_type in [clrs.Type.MASK, clrs.Type.MASK_ONE, clrs.Type.POINTER]:
if spec_type in [
specs.Type.MASK,
specs.Type.MASK_ONE,
specs.Type.POINTER,
]:
return f'{x.astype(int)}'
else:
raise KeyError(f'Feature type not supported in spec {spec[spec_name]}')
Expand All @@ -532,24 +536,24 @@ def _convert_graph_features_to_str(
def _convert_edge_features_to_str(
x: np.ndarray,
spec_name: str,
spec: clrs.Spec,
spec: specs.Spec,
spec_type: str,
edge_masks_as_edge_list: bool,
):
"""Converts edge features into string."""

if edge_masks_as_edge_list:
if spec_type == clrs.Type.MASK or (
spec_type == clrs.Type.SCALAR and _is_binary(x)
if spec_type == specs.Type.MASK or (
spec_type == specs.Type.SCALAR and _is_binary(x)
):
edges = list(zip(*np.nonzero(x > 0)))
return DEFAULT_SEPARATOR.join([f'({x},{y})' for x, y in edges])
else:
match spec_type:
case clrs.Type.POINTER | clrs.Type.MASK | clrs.Type.CATEGORICAL:
if spec_type == clrs.Type.CATEGORICAL:
case specs.Type.POINTER | specs.Type.MASK | specs.Type.CATEGORICAL:
if spec_type == specs.Type.CATEGORICAL:
# lcs_length includes masked elements where the category is -1
mask = np.any(x == clrs.OutputClass.MASKED, axis=-1)
mask = np.any(x == specs.OutputClass.MASKED, axis=-1)
categories = np.argmax(x, axis=-1)
categories[mask] = -1
int_output = categories
Expand All @@ -562,14 +566,14 @@ def _convert_edge_features_to_str(
),
)

case clrs.Type.SCALAR:
case specs.Type.SCALAR:
row_to_str = lambda r: _bracket(' '.join([f'{a:.3g}' for a in r]))
return _bracket(DEFAULT_SEPARATOR.join([row_to_str(r) for r in x]))

raise KeyError(f'Feature type not supported in spec {spec[spec_name]}')


def _get_feature_by_name(examples: clrs.Features, spec_name: str) -> Any:
def _get_feature_by_name(examples: samplers.Features, spec_name: str) -> Any:
filtered_inputs = [
example for example in examples if example.name == spec_name
]
Expand All @@ -590,7 +594,7 @@ def _bracket(s: str) -> str:
return f'[{s}]'


def _do_not_include_input_in_text(spec_name: str, spec: clrs.Spec) -> bool:
def _do_not_include_input_in_text(spec_name: str, spec: specs.Spec) -> bool:
if spec_name == 'pos':
return True
if spec_name == 'adj' and 'A' in spec:
Expand Down

0 comments on commit 9bf6807

Please sign in to comment.