diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 45f428eaa5..30ef248596 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -244,6 +244,23 @@ cc_library( ], ) +cc_test( + name = "constraint_violation_test", + size = "small", + srcs = ["constraint_violation_test.cc"], + deps = [ + ":constraint_violation", + ":cp_model_cc_proto", + ":cp_model_utils", + "//ortools/base", + "//ortools/base:dump_vars", + "//ortools/base:gmock_main", + "//ortools/base:parse_test_proto", + "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "feasibility_jump", srcs = ["feasibility_jump.cc"], @@ -664,6 +681,17 @@ cc_library( ], ) +cc_library( + name = "cp_model_test_utils", + srcs = ["cp_model_test_utils.cc"], + hdrs = ["cp_model_test_utils.h"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_utils", + "@com_google_absl//absl/random", + ], +) + proto_library( name = "boolean_problem_proto", srcs = ["boolean_problem.proto"], @@ -682,9 +710,8 @@ cc_library( ":cp_model_cc_proto", ":cp_model_utils", ":util", - "//ortools/base", "//ortools/base:strong_vector", - "//ortools/base:types", + "//ortools/base:timer", "//ortools/util:bitset", "//ortools/util:logging", "//ortools/util:saturated_arithmetic", @@ -695,7 +722,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/random", "@com_google_absl//absl/random:bit_gen_ref", "@com_google_absl//absl/random:distributions", "@com_google_absl//absl/strings", @@ -784,6 +810,36 @@ cc_test( ], ) +cc_library( + name = "cp_model_table", + srcs = ["cp_model_table.cc"], + hdrs = ["cp_model_table.h"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_utils", + ":presolve_context", + "//ortools/base:stl_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + 
"@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "cp_model_table_test", + size = "small", + srcs = ["cp_model_table_test.cc"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_table", + ":sat_parameters_cc_proto", + "//ortools/base:gmock_main", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "cp_model_presolve", srcs = [ @@ -799,6 +855,7 @@ cc_library( ":cp_model_expand", ":cp_model_mapping", ":cp_model_symmetries", + ":cp_model_table", ":cp_model_utils", ":diffn_util", ":diophantine", @@ -849,6 +906,28 @@ cc_library( ], ) +cc_test( + name = "cp_model_presolve_random_test", + size = "medium", + srcs = ["cp_model_presolve_random_test.cc"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_solver", + ":cp_model_utils", + ":sat_parameters_cc_proto", + "//ortools/base", + "//ortools/base:file", + "//ortools/base:gmock_main", + "//ortools/base:path", + "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "cp_model_postsolve", srcs = [ @@ -888,6 +967,7 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_checker", + ":cp_model_table", ":cp_model_utils", ":presolve_context", ":sat_parameters_cc_proto", @@ -1154,6 +1234,22 @@ cc_library( ], ) +cc_test( + name = "sat_decision_test", + size = "small", + srcs = ["sat_decision_test.cc"], + deps = [ + ":model", + ":sat_base", + ":sat_decision", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + "@com_google_absl//absl/random", + ], +) + cc_library( name = "clause", srcs = ["clause.cc"], @@ -1187,6 +1283,23 @@ cc_library( ], ) +cc_test( + name = "clause_test", + size = "small", + srcs = ["clause_test.cc"], + deps = [ + ":clause", + ":model", + ":sat_base", + 
":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "simplification", srcs = ["simplification.cc"], @@ -1454,6 +1567,20 @@ cc_library( ], ) +cc_test( + name = "lb_tree_search_test", + size = "small", + srcs = ["lb_tree_search_test.cc"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_solver", + ":cp_model_test_utils", + ":lb_tree_search", + ":sat_parameters_cc_proto", + "//ortools/base:gmock_main", + ], +) + cc_library( name = "pseudo_costs", srcs = ["pseudo_costs.cc"], @@ -2022,6 +2149,28 @@ cc_library( ], ) +cc_test( + name = "boolean_problem_test", + size = "small", + srcs = [ + "boolean_problem_test.cc", + "opb_reader.h", + ], + deps = [ + ":boolean_problem", + ":boolean_problem_cc_proto", + "//ortools/algorithms:sparse_permutation", + "//ortools/base", + "//ortools/base:file", + "//ortools/base:gmock_main", + "//ortools/base:path", + "//ortools/util:filelineiter", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "linear_relaxation", srcs = ["linear_relaxation.cc"], @@ -2172,6 +2321,29 @@ cc_library( ], ) +cc_test( + name = "linear_programming_constraint_test", + srcs = ["linear_programming_constraint_test.cc"], + deps = [ + ":cp_model_solver", + ":integer", + ":integer_search", + ":linear_constraint", + ":linear_constraint_manager", + ":linear_programming_constraint", + ":model", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/base:mathutil", + "//ortools/lp_data:base", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "linear_constraint_manager", srcs = ["linear_constraint_manager.cc"], @@ -2604,6 +2776,31 @@ cc_library( ], 
) +cc_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_solver", + ":cp_model_utils", + ":sat_base", + ":sat_parameters_cc_proto", + ":util", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/base:mathutil", + "//ortools/util:random_engine", + "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "stat_tables", srcs = ["stat_tables.cc"], @@ -3086,6 +3283,25 @@ cc_library( ], ) +cc_test( + name = "feasibility_pump_test", + size = "small", + srcs = ["feasibility_pump_test.cc"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_loader", + ":cp_model_mapping", + ":feasibility_pump", + ":integer", + ":linear_constraint", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base:gmock_main", + ], +) + cc_library( name = "rins", srcs = ["rins.cc"], @@ -3249,6 +3465,25 @@ cc_library( ], ) +cc_test( + name = "sat_cnf_reader_test", + size = "small", + srcs = [ + "sat_cnf_reader_test.cc", + ], + deps = [ + ":boolean_problem", + ":boolean_problem_cc_proto", + ":sat_cnf_reader", + "//ortools/base:file", + "//ortools/base:gmock_main", + "//ortools/base:path", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "cp_model_symmetries", srcs = ["cp_model_symmetries.cc"], @@ -3288,6 +3523,24 @@ cc_library( ], ) +cc_test( + name = "cp_model_symmetries_test", + srcs = ["cp_model_symmetries_test.cc"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_symmetries", + ":model", + ":presolve_context", + ":sat_parameters_cc_proto", + "//ortools/algorithms:sparse_permutation", + "//ortools/base:gmock_main", + 
"//ortools/base:parse_test_proto", + "//ortools/util:logging", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "swig_helper", srcs = ["swig_helper.cc"], @@ -3436,6 +3689,21 @@ cc_test( ], ) +cc_test( + name = "diophantine_test", + srcs = ["diophantine_test.cc"], + deps = [ + ":diophantine", + ":util", + "//ortools/base:gmock_main", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_test( name = "inclusion_test", size = "small", diff --git a/ortools/sat/boolean_problem_test.cc b/ortools/sat/boolean_problem_test.cc new file mode 100644 index 0000000000..6a6a476862 --- /dev/null +++ b/ortools/sat/boolean_problem_test.cc @@ -0,0 +1,186 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/boolean_problem.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gtest/gtest.h" +#include "ortools/algorithms/sparse_permutation.h" +#include "ortools/base/helpers.h" +#include "ortools/base/options.h" +#include "ortools/base/path.h" +#include "ortools/sat/boolean_problem.pb.h" +#include "ortools/sat/opb_reader.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(ValidateBooleanProblemTest, Ok) { + std::string file = + "min: 1 x1 1 x2 ;\n" + "1 x1 1 x2 >= 1 ;\n" + "1 x1 1 x2 >= 1 ;\n"; + LinearBooleanProblem problem; + OpbReader reader; + const std::string filename = file::JoinPath(::testing::TempDir(), "file.opb"); + CHECK_OK(file::SetContents(filename, file, file::Defaults())); + CHECK(reader.Load(filename, &problem)); + EXPECT_TRUE(ValidateBooleanProblem(problem).ok()); +} + +TEST(ValidateBooleanProblemTest, ZeroCoefficients) { + std::string file = + "min: 1 x1 1 x2 ;\n" + "1 x1 0 x2 >= 1 ;\n" + "1 x1 1 x2 >= 1 ;\n"; + LinearBooleanProblem problem; + OpbReader reader; + const std::string filename = + file::JoinPath(::testing::TempDir(), "file2.opb"); + CHECK_OK(file::SetContents(filename, file, file::Defaults())); + CHECK(reader.Load(filename, &problem)); + EXPECT_FALSE(ValidateBooleanProblem(problem).ok()); +} + +TEST(ValidateBooleanProblemTest, DuplicateEntries) { + std::string file = + "min: 1 x1 1 x2 ;\n" + "1 x1 1 x2 1 x1 >= 1 ;\n" + "1 x1 1 x2 >= 1 ;\n"; + LinearBooleanProblem problem; + OpbReader reader; + const std::string filename = + file::JoinPath(::testing::TempDir(), "file3.opb"); + CHECK_OK(file::SetContents(filename, file, file::Defaults())); + CHECK(reader.Load(filename, &problem)); + EXPECT_FALSE(ValidateBooleanProblem(problem).ok()); +} + +void FindSymmetries( + absl::string_view file, + std::vector>* generators) { + LinearBooleanProblem problem; + OpbReader 
reader; + static int counter = 4; + ++counter; + const std::string filename = file::JoinPath( + ::testing::TempDir(), absl::StrCat("file", counter, ".opb")); + CHECK_OK(file::SetContents(filename, file, file::Defaults())); + CHECK(reader.Load(filename, &problem)); + FindLinearBooleanProblemSymmetries(problem, generators); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, ProblemWithSymmetry1) { + std::string file = + "min: 1 x1 1 x2 ;\n" + "1 x1 1 x2 >= 1 ;\n" + "1 x1 1 x2 >= 1 ;\n"; + std::vector> generators; + FindSymmetries(file, &generators); + + // Note that the permutation is on the literals: + // xi maps to 2i and not(xi) maps to 2i + 1. + EXPECT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2) (1 3)"); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, ProblemWithSymmetry2) { + std::string file = + "min: 1 x1 1 x2 ;\n" + "-3 x1 -2 x2 >= -1 ;\n"; // This is simplified to both x1 and x2 false. + std::vector> generators; + FindSymmetries(file, &generators); + EXPECT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2) (1 3)"); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, ProblemWithSymmetry3) { + std::string file = + "min: 1 x1 1 x2 1 x3;\n" + " 1 x1 2 x2 3 x3 >= 2 ;\n" + " 1 x2 2 x3 3 x1 >= 2 ;\n" + " 1 x3 2 x1 3 x2 >= 2 ;\n"; + std::vector> generators; + FindSymmetries(file, &generators); + EXPECT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 4 2) (1 5 3)"); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, ProblemWithSymmetry4) { + std::string file = + "min: 1 x1;\n" + " 1 x1 2 x2 >= 2 ;\n" + " 1 x1 -2 x3 >= 0 ;\n"; + std::vector> generators; + FindSymmetries(file, &generators); + EXPECT_EQ(generators.size(), 1); + + // x2 and not(x3) are equivalent. 
+ EXPECT_EQ(generators[0]->DebugString(), "(2 5) (3 4)"); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, ProblemWithoutSymmetry1) { + std::string file = + "min: 1 x1 2 x2 ;\n" + "1 x1 1 x2 >= 1 ;\n"; + std::vector> generators; + FindSymmetries(file, &generators); + EXPECT_EQ(generators.size(), 0); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, ProblemWithoutSymmetry2) { + std::string file = + "min: 1 x1 1 x2 ;\n" + "1 x1 2 x2 >= 2 ;\n"; + std::vector> generators; + FindSymmetries(file, &generators); + EXPECT_EQ(generators.size(), 0); +} + +TEST(FindLinearBooleanProblemSymmetriesTest, PigeonHole) { + // This is the problem of putting 3 pigeons into 2 holes (UNSAT). + // x1: pigeon 1 is in hole 1 + // x2: pigeon 1 is in hole 2 + // x3: pigeon 2 is in hole 1 + // ... + std::string file = + "min: ;\n" + "1 x1 1 x2 >= 1 ;\n" // pigeon 1 should go into one hole + "1 x3 1 x4 >= 1 ;\n" // pigeon 2 should go into one hole + "1 x5 1 x6 >= 1 ;\n" // pigeon 3 should go into one hole + "-1 x1 -1 x3 -1 x5 >= -1 ;\n" // At most 1 pigeon in hole 1 + "-1 x2 -1 x4 -1 x6 >= -1 ;\n"; // At most 1 pigeon in hole 2 + std::vector> generators; + FindSymmetries(file, &generators); + + // The minimal support size is obtained with 3 generators: + // - The two holes are symmetric (x1, x2) (x3, x4) (x5, x6). + // - 2 generators for all the permutations of the 3 pigeons. + // + // But as of 2014-05, the symmetry finder isn't great at reducing the support + // size, but rather performs well at finding few generators, so it finds a + // solution with 2 generators. 
+ EXPECT_EQ(generators.size(), 2); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/clause_test.cc b/ortools/sat/clause_test.cc new file mode 100644 index 0000000000..6c3b0ffd46 --- /dev/null +++ b/ortools/sat/clause_test.cc @@ -0,0 +1,411 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/clause.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +template +auto LiteralsAre(Args... literals) { + return ::testing::ElementsAre(Literal(literals)...); +} + +template +auto UnorderedLiteralsAre(Args... 
literals) { + return ::testing::UnorderedElementsAre(Literal(literals)...); +} + +TEST(SatClauseTest, BasicAllocation) { + std::unique_ptr clause(SatClause::Create(Literals({+1, -2, +4}))); + EXPECT_EQ(3, clause->size()); + EXPECT_EQ(Literal(+1), clause->FirstLiteral()); + EXPECT_EQ(Literal(-2), clause->SecondLiteral()); +} + +struct TestSatClause { + bool is_learned : 1; + bool is_attached : 1; + unsigned int size : 30; + + // We test that Literal literals[0]; does not increase the size. +}; + +TEST(SatClauseTest, ClassSize) { + EXPECT_EQ(4, sizeof(TestSatClause)); + EXPECT_EQ(4, sizeof(SatClause)); +} + +BinaryClause MakeBinaryClause(int a, int b) { + return BinaryClause(Literal(a), Literal(b)); +} + +TEST(BinaryClauseManagerTest, BasicTest) { + BinaryClauseManager manager; + manager.Add(MakeBinaryClause(+2, +3)); + manager.Add(MakeBinaryClause(+3, +2)); // dup + manager.Add(MakeBinaryClause(+1, +4)); + manager.Add(MakeBinaryClause(+5, +4)); + manager.Add(MakeBinaryClause(+4, +1)); // dup + manager.Add(MakeBinaryClause(+2, +3)); // dup + EXPECT_EQ(3, manager.NumClauses()); + EXPECT_THAT(manager.newly_added(), + ElementsAre(MakeBinaryClause(+2, +3), MakeBinaryClause(+1, +4), + MakeBinaryClause(+5, +4))); + manager.ClearNewlyAdded(); + EXPECT_TRUE(manager.newly_added().empty()); + manager.Add(MakeBinaryClause(-1, +2)); + EXPECT_EQ(4, manager.NumClauses()); + EXPECT_THAT(manager.newly_added(), ElementsAre(MakeBinaryClause(-1, +2))); +} + +TEST(BinaryImplicationGraphTest, BasicUnsatSccTest) { + Model model; + model.GetOrCreate()->Resize(10); + auto* graph = model.GetOrCreate(); + graph->Resize(10); + // These are implications. 
+ graph->AddBinaryClause(Literal(+1).Negated(), Literal(+2)); + graph->AddBinaryClause(Literal(+2).Negated(), Literal(+3)); + graph->AddBinaryClause(Literal(+3).Negated(), Literal(-1)); + graph->AddBinaryClause(Literal(-1).Negated(), Literal(+4)); + graph->AddBinaryClause(Literal(+4).Negated(), Literal(+1)); + EXPECT_FALSE(graph->DetectEquivalences()); +} + +TEST(BinaryImplicationGraphTest, DetectEquivalences) { + // We take a bunch of random permutations, equivalence classes will be cycles. + // We make sure the representative of x and not(x) are always negation of + // each other. + absl::BitGen random; + for (int num_passes = 0; num_passes < 10; ++num_passes) { + Model model; + auto* graph = model.GetOrCreate(); + + const int size = 1000; + model.GetOrCreate()->SetNumVariables(size); + std::vector permutation(size); + std::iota(permutation.begin(), permutation.end(), 0); + std::shuffle(permutation.begin(), permutation.end(), random); + for (int i = 0; i < size; ++i) { + // i => permutation[i]. + const int signed_value = i + 1; + const int signed_image = permutation[i] + 1; + graph->AddBinaryClause(Literal(signed_value).Negated(), + Literal(signed_image)); + } + + EXPECT_TRUE(graph->DetectEquivalences()); + int num_classes = 0; + for (int i = 0; i < size; ++i) { + const int signed_value = i + 1; + if (graph->RepresentativeOf(Literal(signed_value)) == + Literal(signed_value)) { + ++num_classes; + } else { + EXPECT_EQ(graph->RepresentativeOf(Literal(signed_value)).Negated(), + graph->RepresentativeOf(Literal(signed_value).Negated())); + } + } + + // It is unlikely that std::shuffle() produce the identity permutation, so + // this is not flaky and shows that there is some detection going on. + EXPECT_GT(num_classes, 0); + EXPECT_LT(num_classes, size); + } +} + +TEST(BinaryImplicationGraphTest, DetectEquivalencesWithAtMostOnes) { + Model model; + auto* graph = model.GetOrCreate(); + model.GetOrCreate()->SetNumVariables(10); + + // 2 and 4 are equivalent. 
They actually must be true in this setting. + EXPECT_TRUE(graph->AddAtMostOne(Literals({1, 2, 3}))); + EXPECT_TRUE(graph->AddAtMostOne(Literals({1, 4, 3}))); + graph->AddBinaryClause(Literal(1), Literal(4)); + graph->AddBinaryClause(Literal(3), Literal(2)); + + const auto& assignment = model.GetOrCreate()->Assignment(); + EXPECT_TRUE(graph->DetectEquivalences()); + EXPECT_TRUE(assignment.LiteralIsTrue(Literal(2))); + EXPECT_TRUE(assignment.LiteralIsTrue(Literal(4))); + EXPECT_TRUE(assignment.LiteralIsFalse(Literal(1))); + EXPECT_TRUE(assignment.LiteralIsFalse(Literal(3))); +} + +TEST(BinaryImplicationGraphTest, TransitiveReduction) { + Model model; + model.GetOrCreate()->SetNumVariables(10); + auto* graph = model.GetOrCreate(); + + for (BooleanVariable i(0); i < 10; ++i) { + for (BooleanVariable j(i + 1); j < 10; ++j) { + // i => j + graph->AddBinaryClause(Literal(i, false), Literal(j, true)); + } + + // These trivial clauses are filtered. + graph->AddBinaryClause(Literal(i, false), Literal(i, true)); + } + + EXPECT_EQ(graph->num_implications(), 10 * 9); + EXPECT_TRUE(graph->ComputeTransitiveReduction()); + EXPECT_EQ(graph->num_implications(), 9 * 2); +} + +// This basically just test our DCHECKs. +TEST(BinaryImplicationGraphTest, BasicRandomTransitiveReduction) { + Model model; + const int num_vars = 200; + model.GetOrCreate()->SetNumVariables(num_vars); + auto* graph = model.GetOrCreate(); + + // We add a lot of a => b (we might not have a DAG). 
+ absl::BitGen random; + int num_added = 0; + for (int i = 0; i < 10'000; ++i) { + const BooleanVariable a(absl::Uniform(random, 0, num_vars)); + const BooleanVariable b(absl::Uniform(random, 0, num_vars)); + if (a == b) continue; + ++num_added; + graph->AddImplication(Literal(a, true), Literal(b, true)); + } + + EXPECT_EQ(graph->num_implications(), 2 * num_added); + EXPECT_TRUE(graph->ComputeTransitiveReduction()); + EXPECT_LT(graph->num_implications(), num_added); +} + +// We generate a random 2-SAT problem, and check that the propagation is +// unchanged whether or not the graph is reduced. +TEST(BinaryImplicationGraph, RandomTransitiveReduction) { + // These leads to a not trivial 2-SAT space with more than just all zero + // and all 1 as solution. + const int num_variables = 100; + const int num_constraints = 200; + + Model model1; + Model model2; + auto* sat1 = model1.GetOrCreate(); + auto* sat2 = model2.GetOrCreate(); + sat1->SetNumVariables(num_variables); + sat2->SetNumVariables(num_variables); + + absl::BitGen random; + for (int i = 0; i < num_constraints; ++i) { + // Because we only use positive literal, we are never UNSAT. + const Literal a = + Literal(BooleanVariable(absl::Uniform(random, 0, num_variables)), true); + const Literal b = + Literal(BooleanVariable(absl::Uniform(random, 0, num_variables)), true); + + // a => b. 
+ sat1->AddBinaryClause(a.Negated(), b); + sat2->AddBinaryClause(a.Negated(), b); + sat2->AddBinaryClause(a.Negated(), b); + } + + auto* graph2 = model2.GetOrCreate(); + EXPECT_TRUE(graph2->ComputeTransitiveReduction()); + EXPECT_TRUE(sat1->Propagate()); + EXPECT_TRUE(sat2->Propagate()); + + absl::flat_hash_set propagated; + for (BooleanVariable var(0); var < num_variables; ++var) { + sat1->Backtrack(0); + sat2->Backtrack(0); + sat1->EnqueueDecisionIfNotConflicting(Literal(var, true)); + sat2->EnqueueDecisionIfNotConflicting(Literal(var, true)); + EXPECT_EQ(sat1->LiteralTrail().Index(), sat2->LiteralTrail().Index()); + propagated.clear(); + for (int i = 0; i < sat1->LiteralTrail().Index(); ++i) { + propagated.insert(sat1->LiteralTrail()[i].Index()); + } + for (int i = 0; i < sat2->LiteralTrail().Index(); ++i) { + EXPECT_TRUE(propagated.contains(sat2->LiteralTrail()[i].Index())); + } + } +} + +TEST(BinaryImplicationGraphTest, BasicCliqueDetection) { + std::vector> at_most_ones; + at_most_ones.push_back({Literal(+1), Literal(+2)}); + at_most_ones.push_back({Literal(+1), Literal(+3)}); + at_most_ones.push_back({Literal(+2), Literal(+3)}); + + Model model; + auto* graph = model.GetOrCreate(); + graph->Resize(10); + model.GetOrCreate()->Resize(10); + for (const std::vector& at_most_one : at_most_ones) { + EXPECT_TRUE(graph->AddAtMostOne(at_most_one)); + } + graph->TransformIntoMaxCliques(&at_most_ones); + EXPECT_THAT(at_most_ones[0], LiteralsAre(+1, +2, +3)); + EXPECT_TRUE(at_most_ones[1].empty()); + EXPECT_TRUE(at_most_ones[2].empty()); +} + +TEST(BinaryImplicationGraphTest, CliqueDetectionAndDuplicates) { + std::vector> at_most_ones; + at_most_ones.push_back({Literal(+1), Literal(+2)}); + at_most_ones.push_back({Literal(+2), Literal(+2)}); + + Model model; + auto* graph = model.GetOrCreate(); + model.GetOrCreate()->SetNumVariables(10); + for (const std::vector& at_most_one : at_most_ones) { + EXPECT_TRUE(graph->AddAtMostOne(at_most_one)); + } + + // Here we do not 
change the clique. + graph->TransformIntoMaxCliques(&at_most_ones); + EXPECT_THAT(at_most_ones, + ElementsAre(LiteralsAre(+1, +2), LiteralsAre(+2, +2))); + + // Clique detection call the SCC which will see that 2 must be false... + const auto& assignment = model.GetOrCreate()->Assignment(); + EXPECT_FALSE(assignment.LiteralIsAssigned(Literal(1))); + EXPECT_TRUE(assignment.LiteralIsFalse(Literal(2))); +} + +TEST(BinaryImplicationGraphTest, AddAtMostOneWithDuplicates) { + Model model; + auto* trail = model.GetOrCreate(); + auto* graph = model.GetOrCreate(); + trail->Resize(10); + graph->Resize(10); + EXPECT_TRUE(graph->AddAtMostOne(Literals({+1, +2, +3, +2}))); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+2))); +} + +TEST(BinaryImplicationGraphTest, AddAtMostOneWithTriples) { + Model model; + auto* trail = model.GetOrCreate(); + auto* graph = model.GetOrCreate(); + trail->Resize(10); + graph->Resize(10); + EXPECT_TRUE(graph->AddAtMostOne(Literals({+1, +2, +3, +2, +2}))); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+2))); +} + +TEST(BinaryImplicationGraphTest, AddAtMostOneCornerCase) { + Model model; + auto* trail = model.GetOrCreate(); + auto* graph = model.GetOrCreate(); + trail->Resize(10); + graph->Resize(10); + EXPECT_TRUE(graph->AddAtMostOne(Literals({+1, +2, +3, +2, -2}))); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+1))); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+2))); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+3))); +} + +TEST(BinaryImplicationGraphTest, LargeAtMostOnePropagation) { + const int kNumVariables = 1e6; + + Model model; + auto* trail = model.GetOrCreate(); + auto* graph = model.GetOrCreate(); + trail->Resize(kNumVariables); + graph->Resize(kNumVariables); + + std::vector large_at_most_one; + for (int i = 0; i < kNumVariables; ++i) { + large_at_most_one.push_back(Literal(BooleanVariable(i), true)); + } + EXPECT_TRUE(graph->AddAtMostOne(large_at_most_one)); + + const Literal 
decision = Literal(BooleanVariable(42), true); + trail->SetDecisionLevel(1); + trail->EnqueueSearchDecision(Literal(decision)); + EXPECT_TRUE(graph->Propagate(trail)); + + const auto& assignment = trail->Assignment(); + for (int i = 0; i < kNumVariables; ++i) { + const Literal l = Literal(BooleanVariable(i), true); + if (i == 42) { + EXPECT_TRUE(assignment.LiteralIsTrue(l)); + } else { + EXPECT_TRUE(assignment.LiteralIsFalse(l)); + EXPECT_EQ(trail->Reason(l.Variable()), + absl::Span({decision.Negated()})); + } + } +} + +TEST(BinaryImplicationGraphTest, HeuristicAmoPartition) { + const int kNumVariables = 1e6; + + Model model; + auto* trail = model.GetOrCreate(); + model.GetOrCreate()->set_at_most_one_max_expansion_size(2); + auto* graph = model.GetOrCreate(); + trail->Resize(kNumVariables); + graph->Resize(kNumVariables); + + EXPECT_TRUE(graph->AddAtMostOne(Literals({+1, +2, +3, +4}))); + EXPECT_TRUE(graph->AddAtMostOne(Literals({+4, +5, +6, +7}))); + + std::vector literals = Literals({+1, +2, +5, +6, +7}); + EXPECT_THAT(graph->HeuristicAmoPartition(&literals), + ElementsAre(UnorderedLiteralsAre(+5, +6, +7), + UnorderedLiteralsAre(+1, +2))); + + EXPECT_TRUE(graph->AddAtMostOne(Literals({+1, +2, +6, +7}))); + EXPECT_THAT(graph->HeuristicAmoPartition(&literals), + ElementsAre(LiteralsAre(+6, +7, +2, +1))); +} + +TEST(BinaryImplicationGraphTest, RandomImpliedLiteral) { + Model model; + auto* trail = model.GetOrCreate(); + auto* graph = model.GetOrCreate(); + trail->Resize(100); + graph->Resize(100); + + graph->AddImplication(Literal(+1), Literal(+6)); + graph->AddImplication(Literal(+1), Literal(+7)); + EXPECT_TRUE(graph->AddAtMostOne(Literals({+1, +2, +4, +5}))); + + absl::flat_hash_set seen; + for (int i = 0; i < 100; ++i) { + seen.insert(graph->RandomImpliedLiteral(Literal(+1))); + } + EXPECT_THAT(seen, UnorderedLiteralsAre(-2, -4, -5, +6, +7)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git 
a/ortools/sat/constraint_violation_test.cc b/ortools/sat/constraint_violation_test.cc new file mode 100644 index 0000000000..6b9e468c45 --- /dev/null +++ b/ortools/sat/constraint_violation_test.cc @@ -0,0 +1,729 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/constraint_violation.h" + +#include +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/dump_vars.h" +#include "ortools/base/gmock.h" +#include "ortools/base/logging.h" +#include "ortools/base/parse_test_proto.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::google::protobuf::contrib::parse_proto::ParseTestProto; +using ::testing::ElementsAre; + +TEST(LinearEvaluatorTest, TestAPI) { + LinearIncrementalEvaluator evaluator; + const int c1 = evaluator.NewConstraint({0, 10}); + EXPECT_EQ(c1, 0); + const int c2 = evaluator.NewConstraint({1, 20}); + EXPECT_EQ(c2, 1); + evaluator.AddTerm(0, 0, 2); + EXPECT_TRUE(evaluator.VarIsConsistent(0)); + EXPECT_TRUE(evaluator.VarIsConsistent(1)); + evaluator.AddTerm(0, 0, 3); + EXPECT_TRUE(evaluator.VarIsConsistent(0)); + evaluator.AddTerm(1, 0, 1); + EXPECT_TRUE(evaluator.VarIsConsistent(0)); + if (!DEBUG_MODE) { + evaluator.AddTerm(0, 0, 5); + EXPECT_FALSE(evaluator.VarIsConsistent(0)); + } +} + 
+TEST(LinearEvaluatorTest, IncrementalScoreComputationForEnforcement) { + LinearIncrementalEvaluator evaluator; + const int c = evaluator.NewConstraint({1, 1}); + evaluator.AddEnforcementLiteral(c, PositiveRef(0)); + evaluator.AddEnforcementLiteral(c, NegatedRef(1)); + evaluator.AddEnforcementLiteral(c, PositiveRef(2)); + evaluator.PrecomputeCompactView({1, 1, 1}); // All Booleans. + + std::vector weights{1.0}; + std::vector solution{0, 0, 0}; + std::vector jump_deltas{1, 1, 1}; + std::vector jump_scores(3, 0.0); + std::vector modified_constraints; + + // For all possible solution, we try all possible move. + for (int sol = 0; sol < 8; ++sol) { + for (int move = 0; move < 3; ++move) { + // Initialize base solution. + for (int var = 0; var < 3; ++var) { + solution[var] = ((sol >> var) & 1); + jump_deltas[var] = (1 ^ solution[var]) - solution[var]; + } + evaluator.ComputeInitialActivities(solution); + for (int var = 0; var < 3; ++var) { + jump_scores[var] = + evaluator.WeightedViolationDelta(weights, var, jump_deltas[var]); + } + + // Perform move. + evaluator.UpdateVariableAndScores( + move, jump_deltas[move], weights, jump_deltas, + absl::MakeSpan(jump_scores), &modified_constraints); + + // We never update the score of the flipped variable. + solution[move] = 1 ^ solution[move]; + jump_deltas[move] = (1 ^ solution[move]) - solution[move]; + jump_scores[move] = + evaluator.WeightedViolationDelta(weights, move, jump_deltas[move]); + + // Test that the scores are correctly updated. 
+ for (int test = 0; test < 3; ++test) { + ASSERT_EQ(jump_scores[test], evaluator.WeightedViolationDelta( + weights, test, jump_deltas[test])) + << DUMP_VARS(solution) << "\n" + << DUMP_VARS(move) << "\n" + << DUMP_VARS(test); + } + } + } +} + +TEST(LinearEvaluatorTest, EmptyConstraintDoNotCrash) { + LinearIncrementalEvaluator evaluator; + evaluator.NewConstraint({1, 1}); + evaluator.NewConstraint({1, 1}); + evaluator.NewConstraint({1, 1}); + + std::vector solution{0, 0, 0}; + std::vector jump_deltas{1, 1, 1}; + std::vector jump_scores(3, 0.0); + + evaluator.PrecomputeCompactView({10, 10, 10}); + evaluator.ComputeInitialActivities(solution); + evaluator.UpdateScoreOnWeightUpdate(1, jump_deltas, + absl::MakeSpan(jump_scores)); +} + +TEST(ConstraintViolationTest, BasicExactlyOneExampleNonViolated) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { exactly_one { literals: [ 0, 1, 2, 3 ] } } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0, 0, 1}); + EXPECT_EQ(0, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicExactlyOneExampleViolated) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { exactly_one { literals: [ 0, 1, 2, 3 ] } } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0, 1, 1}); + EXPECT_EQ(1, ls.SumOfViolations()); + EXPECT_THAT(ls.ViolatedConstraints(), ElementsAre(0)); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(0), 1); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(1), 1); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(2), 1); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(3), 1); +} + 
+TEST(ConstraintViolationTest, BasicBoolOrViolated) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { bool_or { literals: [ 0, -2, 2, -4 ] } } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 1, 0, 1}); + EXPECT_EQ(1, ls.SumOfViolations()); + ls.ComputeAllViolations({0, 0, 0, 1}); + EXPECT_EQ(0, ls.SumOfViolations()); + ls.ComputeAllViolations({0, 1, 0, 0}); + EXPECT_EQ(0, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicLinearExample) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 4 ] } + variables { domain: [ 0, 5 ] } + constraints { + linear { + vars: [ 0, 1 ], + coeffs: [ 2, 3 ], + domain: [ 1, 4 ], + } + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0}); + EXPECT_EQ(1, ls.SumOfViolations()); + ls.ComputeAllViolations({2, 0}); + EXPECT_EQ(0, ls.SumOfViolations()); + ls.ComputeAllViolations({2, 3}); + EXPECT_EQ(9, ls.SumOfViolations()); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(0), 1); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(1), 1); +} + +TEST(ConstraintViolationTest, BasicObjectiveExampleWithChange) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + objective { + vars: [ 0, 1, 2, 3 ], + coeffs: [ 2, 3, 4, 5 ], + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0, 0, 1}); + EXPECT_EQ(0, ls.SumOfViolations()); + EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(0), 0); + ls.ReduceObjectiveBounds(0, 3); + EXPECT_EQ(2, ls.SumOfViolations()); + EXPECT_THAT(ls.ViolatedConstraints(), ElementsAre(0)); + 
EXPECT_EQ(ls.NumViolatedConstraintsForVarIgnoringObjective(0), 0); +} + +TEST(ConstraintViolationTest, BasicBoolXorExample) { + const ConstraintProto ct_proto = + ParseTestProto(R"pb(bool_xor { literals: [ 0, -2, 2 ] })pb"); + CompiledBoolXorConstraint ct(ct_proto); + EXPECT_EQ(0, ct.ComputeViolation({1, 1, 0})); + EXPECT_EQ(0, ct.ComputeViolation({0, 0, 0})); + EXPECT_EQ(1, ct.ComputeViolation({1, 0, 0})); +} + +TEST(ConstraintViolationTest, BasicLinMaxExampleNoViolation) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + lin_max { + target { vars: 0 coeffs: 2 } + exprs { vars: 1 coeffs: 1 offset: 1 } + exprs { vars: 2 coeffs: 4 offset: 1 } + exprs { offset: 1 } + } + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({1, 1, 0}); + EXPECT_EQ(0, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicLinMaxExampleExcessViolation) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + lin_max { + target { vars: 0 coeffs: 2 } + exprs { vars: 1 coeffs: 1 offset: 1 } + exprs { vars: 2 coeffs: 4 offset: 1 } + exprs { offset: 1 } + } + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0, 0}); + EXPECT_EQ(3, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicLinMaxExampleMissingViolation) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + lin_max { + target { vars: 0 coeffs: 2 } + exprs { vars: 1 coeffs: 1 offset: 1 } + exprs { vars: 2 coeffs: 4 offset: 1 } + exprs { offset: 1 } + } + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({1, 0, 0}); + EXPECT_EQ(1, 
ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicLinMaxExampleNegativeCoeffs) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 40 ] } + variables { domain: [ 0, 40 ] } + variables { domain: [ 0, 40 ] } + constraints { + lin_max { + target { vars: 2 coeffs: -1 } + exprs { vars: 0 coeffs: -1 offset: -1 } + exprs { vars: 1 coeffs: -1 offset: -1 } + } + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({33, 33, 33}); + EXPECT_EQ(1, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicIntProdExample) { + const ConstraintProto ct_proto = ParseTestProto(R"pb( + int_prod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 offset: 1 } + exprs { vars: 2 coeffs: 2 } + } + )pb"); + + CompiledIntProdConstraint ct(ct_proto); + EXPECT_EQ(1, ct.ComputeViolation({1, 0, 0})); + EXPECT_EQ(3, ct.ComputeViolation({1, 0, 2})); + EXPECT_EQ(2, ct.ComputeViolation({6, 0, 2})); +} + +TEST(ConstraintViolationTest, BasicIntDivExample) { + const ConstraintProto ct_proto = ParseTestProto(R"pb( + int_div { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 offset: 1 } + exprs { vars: 2 coeffs: 1 } + } + )pb"); + CompiledIntDivConstraint ct(ct_proto); + EXPECT_EQ(1, ct.ComputeViolation({0, 1, 2})); + EXPECT_EQ(3, ct.ComputeViolation({0, 6, 2})); +} + +TEST(ConstraintViolationTest, BasicIntModExample) { + const ConstraintProto ct_proto = ParseTestProto(R"pb( + int_mod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + )pb"); + CompiledIntModConstraint ct(ct_proto); + EXPECT_EQ(1, ct.ComputeViolation({1, 2, 3})); + EXPECT_EQ(0, ct.ComputeViolation({1, 7, 3})); + // Wrap around. + EXPECT_EQ(2, ct.ComputeViolation({1, 5, 6})); + EXPECT_EQ(2, ct.ComputeViolation({5, 1, 6})); + // Across 0. 
+ EXPECT_EQ(22, ct.ComputeViolation({18, -4, 6})); + EXPECT_EQ(22, ct.ComputeViolation({-18, 4, 6})); +} + +TEST(ConstraintViolationTest, BasicAllDiffExample) { + const ConstraintProto ct_proto = ParseTestProto(R"pb( + all_diff { + exprs { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + exprs { vars: 3 coeffs: 1 } + } + )pb"); + CompiledAllDiffConstraint ct(ct_proto); + EXPECT_EQ(3, ct.ComputeViolation({2, 1, 2, 2})); + EXPECT_EQ(6, ct.ComputeViolation({2, 2, 2, 2})); + EXPECT_EQ(1, ct.ComputeViolation({1, 2, 3, 1})); + EXPECT_EQ(2, ct.ComputeViolation({1, 2, 2, 1})); +} + +TEST(ConstraintViolationTest, BasicNoOverlapExample) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + interval { + start: { vars: 0 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 0 coeffs: 1 offset: 4 } + } + } + constraints { + interval { + start: { vars: 1 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 1 coeffs: 1 offset: 4 } + } + } + constraints { + interval { + start: { vars: 2 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 2 coeffs: 1 offset: 4 } + } + } + constraints { no_overlap { intervals: [ 0, 1, 2 ] } } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 4, 8, 1}); + EXPECT_EQ(0, ls.SumOfViolations()); + + ls.ComputeAllViolations({0, 2, 4, 1}); + EXPECT_EQ(4, ls.SumOfViolations()); + + ls.ComputeAllViolations({0, 0, 0, 1}); + EXPECT_EQ(12, ls.SumOfViolations()); + + ls.ComputeAllViolations({1, 2, 3, 1}); + EXPECT_EQ(8, ls.SumOfViolations()); + + ls.ComputeAllViolations({1, 2, 3, 0}); + EXPECT_EQ(3, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, TwoIntervalsNoOverlapExample) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables 
{ domain: [ 0, 1 ] } + constraints { + enforcement_literal: 2 + interval { + start: { vars: 0 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 0 coeffs: 1 offset: 4 } + } + } + constraints { + interval { + start: { vars: 1 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 1 coeffs: 1 offset: 4 } + } + } + constraints { no_overlap { intervals: [ 0, 1 ] } } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 4, 1}); + EXPECT_EQ(0, ls.SumOfViolations()); + + ls.ComputeAllViolations({0, 2, 1}); + EXPECT_EQ(2, ls.SumOfViolations()); + + ls.ComputeAllViolations({0, 0, 1}); + EXPECT_EQ(4, ls.SumOfViolations()); + + ls.ComputeAllViolations({1, 2, 1}); + EXPECT_EQ(3, ls.SumOfViolations()); + + ls.ComputeAllViolations({1, 2, 0}); + EXPECT_EQ(0, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, BasicCumulativeExample) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 5 ] } + variables { domain: [ 2, 4 ] } + constraints { + enforcement_literal: 3 + interval { + start: { vars: 0 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 0 coeffs: 1 offset: 4 } + } + } + constraints { + interval { + start: { vars: 1 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 1 coeffs: 1 offset: 4 } + } + } + constraints { + interval { + start: { vars: 2 coeffs: 1 }, + size: { offset: 4 }, + end: { vars: 2 coeffs: 1 offset: 4 } + } + } + constraints { + cumulative { + intervals: [ 0, 1, 2 ] + demands: { offset: 2 } + demands: { offset: 2 } + demands: { vars: 4, coeffs: 1 } + capacity: { vars: 5 coeffs: 1 } + } + } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 4, 8, 1, 2, 2}); + EXPECT_EQ(0, ls.SumOfViolations()); + + ls.ComputeAllViolations({0, 2, 4, 1, 2, 2}); + EXPECT_EQ(8, ls.SumOfViolations()); + + 
ls.ComputeAllViolations({0, 0, 0, 1, 1, 3}); + EXPECT_EQ(8, ls.SumOfViolations()); + + ls.ComputeAllViolations({1, 2, 3, 1, 1, 4}); + EXPECT_EQ(2, ls.SumOfViolations()); + + ls.ComputeAllViolations({1, 2, 3, 0, 1, 4}); + EXPECT_EQ(0, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, EmptyNoOverlap) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { no_overlap {} } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 4, 8}); + EXPECT_EQ(0, ls.SumOfViolations()); +} + +TEST(ConstraintViolationTest, WeightedViolationAndDelta) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 4 ] } + variables { domain: [ 0, 5 ] } + constraints { + linear { + vars: [ 0, 1 ], + coeffs: [ 2, 3 ], + domain: [ 1, 4 ], + } + } + constraints { + linear { + vars: [ 0, 1 ], + coeffs: [ 7, 8 ], + domain: [ 1, 20 ], + } + } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + + std::vector solution{0, 0}; + std::vector weight{0.0, 0.0}; + for (int i = 0; i < 10; ++i) { + solution[0] = i; + for (int j = 1; j < 10; ++j) { + solution[1] = 0; + + ls.ComputeAllViolations(solution); + const double delta = + ls.WeightedViolationDelta(/*linear_only=*/false, weight, 1, j, + absl::MakeSpan(solution)); // 0 -> j + const double expected = ls.WeightedViolation(weight) + delta; + + solution[1] = j; + ls.ComputeAllViolations(solution); + EXPECT_EQ(expected, ls.WeightedViolation(weight)); + } + } +} + +TEST(ConstraintViolationTest, Breakpoints) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 4 ] } + variables { domain: [ 0, 5 ] } + constraints { + linear { + vars: [ 0, 1 ], + coeffs: [ 2, 3 ], + domain: [ 1, 4 ], + } + } + constraints { + linear { + vars: [ 0, 1 ], + coeffs: [ 7, 8 ], + domain: [ 1, 20 ], + } + } + )pb"); + SatParameters params; + LsEvaluator 
ls(model, params); + ls.ComputeAllViolations({0, 0}); + + // We don't want the same value as zero, so we should include both values + // around it to be sure we don't miss the minimum. + // + // breakpoints for the first constraint should be at 0,1 and 2. + // breakpoints for seconds constraints should be at 0,1 and 2,3. + EXPECT_THAT( + ls.MutableLinearEvaluator()->SlopeBreakpoints(0, 0, Domain(-5, 8)), + ::testing::ElementsAre(-5, 0, 1, 2, 3, 8)); +} + +TEST(ConstraintViolationTest, BasicCircuit) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } # 0->1 + variables { domain: [ 0, 1 ] } # 1->2 + variables { domain: [ 0, 1 ] } # 1->0 + variables { domain: [ 0, 1 ] } # 2->0 + variables { domain: [ 0, 1 ] } # 2->2 + variables { domain: [ 0, 1 ] } # 0->2 + variables { domain: [ 0, 1 ] } # 0->0 + variables { domain: [ 0, 1 ] } # 2->1 + constraints { + circuit { + tails: [ 0, 1, 1, 2, 2, 0, 0, 2 ] + heads: [ 1, 2, 0, 0, 2, 2, 0, 1 ] + literals: [ 0, 1, 2, 3, 4, 5, 6, 7 ] + } + } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0, 0, 0, 0, 0, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1); + ls.ComputeAllViolations({1, 0, 1, 0, 0, 0, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1); + ls.ComputeAllViolations({1, 0, 1, 0, 1, 0, 0, 0}); + EXPECT_EQ(ls.SumOfViolations(), 0); + ls.ComputeAllViolations({1, 0, 1, 0, 0, 1, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1); + ls.ComputeAllViolations({1, 0, 1, 1, 0, 1, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1); + ls.ComputeAllViolations({1, 1, 0, 1, 0, 0, 0, 0}); + EXPECT_EQ(ls.SumOfViolations(), 0); + ls.ComputeAllViolations({0, 1, 0, 0, 0, 0, 1, 1}); + EXPECT_EQ(ls.SumOfViolations(), 0); +} + +TEST(ConstraintViolationTest, BasicMultiCircuit) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } # 0->1 + variables { domain: [ 0, 1 ] } # 1->2 + variables { domain: [ 0, 1 ] } # 1->0 + variables { domain: [ 0, 1 ] } # 2->0 + 
variables { domain: [ 0, 1 ] } # 2->2 + variables { domain: [ 0, 1 ] } # 0->2 + variables { domain: [ 0, 1 ] } # 2->1 + constraints { + routes { + tails: [ 0, 1, 1, 2, 2, 0, 2 ] + heads: [ 1, 2, 0, 0, 2, 2, 1 ] + literals: [ 0, 1, 2, 3, 4, 5, 6 ] + } + } + )pb"); + + SatParameters params; + LsEvaluator ls(model, params); + ls.ComputeAllViolations({0, 0, 0, 0, 0, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1) << "arcs: None"; + ls.ComputeAllViolations({1, 0, 1, 0, 0, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1) << "arcs: 0->1;1->0"; + ls.ComputeAllViolations({1, 0, 1, 0, 0, 1, 0}); + EXPECT_GE(ls.SumOfViolations(), 1) << "arcs: 0->1;1->0;0->2"; + ls.ComputeAllViolations({1, 0, 1, 1, 0, 0, 0}); + EXPECT_GE(ls.SumOfViolations(), 1) << "arcs: 0->1;1->0;2->0"; + ls.ComputeAllViolations({1, 0, 1, 0, 1, 0, 0}); + EXPECT_EQ(ls.SumOfViolations(), 0) << "arcs: 0->1;1->0;2->2"; + ls.ComputeAllViolations({1, 0, 1, 1, 0, 1, 0}); + EXPECT_EQ(ls.SumOfViolations(), 0) << "arcs: 0->1;1->0;0->2;2->0"; + ls.ComputeAllViolations({1, 0, 1, 1, 0, 1, 0}); + EXPECT_EQ(ls.SumOfViolations(), 0) << "arcs: 0->1;1->0;0->2;2->0"; + ls.ComputeAllViolations({0, 1, 0, 0, 0, 0, 1}); + EXPECT_GE(ls.SumOfViolations(), 1) << "arcs: 1->2;2->1"; +} + +TEST(ConstraintViolationTest, LastUpdateViolationChanges) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 4 ] } + variables { domain: [ 0, 5 ] } + variables { domain: [ 0, 20 ] } + constraints { + linear { + vars: [ 0, 1 ], + coeffs: [ 2, 3 ], + domain: [ 1, 4 ], + } + } + constraints { + int_prod { + target { vars: 2 coeffs: 1 } + exprs { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + } + } + )pb"); + SatParameters params; + LsEvaluator ls(model, params); + std::vector unused_jump_scores = {0.0, 0.0, 0.0}; + + std::vector solution = {2, 1, 3}; + ls.ComputeAllViolations(solution); + + solution[0] = 3; + ls.UpdateLinearScores(0, 2, 3, /*weights=*/{1.0, 1.0}, + /*jump_deltas=*/{-2, -1, -3}, + 
absl::MakeSpan(unused_jump_scores)); + ls.UpdateNonLinearViolations(0, 2, solution); + EXPECT_THAT(ls.last_update_violation_changes(), ElementsAre(0, 1)); + + solution[2] = 2; + ls.UpdateLinearScores(2, 3, 2, /*weights=*/{1.0, 1.0}, + /*jump_deltas=*/{-2, -1, 3}, + absl::MakeSpan(unused_jump_scores)); + ls.UpdateNonLinearViolations(2, 3, solution); + EXPECT_THAT(ls.last_update_violation_changes(), ElementsAre(1)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index ec93570e43..9c4cc52875 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -36,6 +36,7 @@ #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/cp_model_table.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/presolve_context.h" #include "ortools/sat/sat_parameters.pb.h" diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 4080d435f1..3a1744f26b 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -60,6 +60,7 @@ #include "ortools/sat/cp_model_expand.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_symmetries.h" +#include "ortools/sat/cp_model_table.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/diffn_util.h" #include "ortools/sat/diophantine.h" diff --git a/ortools/sat/cp_model_presolve_random_test.cc b/ortools/sat/cp_model_presolve_random_test.cc new file mode 100644 index 0000000000..dca4a12157 --- /dev/null +++ b/ortools/sat/cp_model_presolve_random_test.cc @@ -0,0 +1,329 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file tests the various presolves by asserting that the result of a +// randomly generated linear integer program is the same with or without the +// presolve step. The linear programs are generated in a way that tries to cover +// all the corner cases of the preprocessor (for the linear part). + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/log/check.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "gtest/gtest.h" +#include "ortools/base/helpers.h" +#include "ortools/base/logging.h" +#include "ortools/base/options.h" +#include "ortools/base/path.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/sorted_interval_list.h" + +ABSL_FLAG(std::string, dump_dir, "", + "If non-empty, a dir where all the models are dumped."); + +namespace operations_research { +namespace sat { +namespace { + +int64_t GetRandomNonZero(int max_magnitude, absl::BitGen* random) { + if (absl::Bernoulli(*random, 0.5)) { + return absl::Uniform(*random, -max_magnitude, -1); + } + return absl::Uniform(*random, 1, max_magnitude); +} + +int64_t GetRandomNonZeroAndNonInvertible(int max_magnitude, + absl::BitGen* random) { + if (absl::Bernoulli(*random, 0.5)) { + return absl::Uniform(*random, -max_magnitude, -1); + } + return absl::Uniform(*random, 2, max_magnitude); +} 
+ +// Generate an initial linear program that will be extended later with new +// variables and constraints that the preprocessors should be able to remove. +CpModelProto GenerateRandomBaseProblem(absl::BitGen* random) { + CpModelProto result; + result.set_name("Random IP"); + const int num_variables = absl::Uniform(*random, 1, 20); + const int num_constraints = absl::Uniform(*random, 1, 20); + + for (int i = 0; i < num_variables; ++i) { + sat::IntegerVariableProto* var = result.add_variables(); + var->add_domain(absl::Uniform(*random, -10, 10)); + var->add_domain(absl::Uniform(*random, var->domain(0), 10)); + } + for (int i = 0; i < num_constraints; ++i) { + auto* ct = result.add_constraints()->mutable_linear(); + ct->add_domain(absl::Uniform(*random, -100, 100)); + ct->add_domain(absl::Uniform(*random, ct->domain(0), 100)); + for (int v = 0; v < num_variables; ++v) { + if (absl::Bernoulli(*random, 0.2)) { // Sparser. + ct->add_vars(v); + ct->add_coeffs(GetRandomNonZero(10, random)); + } + } + } + + std::vector all_variables(num_variables); + std::iota(begin(all_variables), end(all_variables), 0); + std::shuffle(begin(all_variables), end(all_variables), *random); + for (const int v : all_variables) { + if (absl::Bernoulli(*random, 0.5)) { + result.mutable_objective()->add_vars(v); + result.mutable_objective()->add_coeffs( + absl::Uniform(*random, -100, 100)); + } + } + result.mutable_objective()->set_offset( + absl::Uniform(*random, -100, 100)); + result.mutable_objective()->set_scaling_factor( + absl::Uniform(*random, -100, 100)); + + return result; +} + +// Adds a row to the given problem which is a duplicate (with a random +// proportionality factor) of a random row. 
+void AddRandomDuplicateRow(absl::BitGen* random, CpModelProto* proto) { + const int64_t factor = GetRandomNonZero(10, random); + const LinearConstraintProto& source = + proto + ->constraints(absl::Uniform(*random, 0, + proto->constraints().size() - 1)) + .linear(); + auto* ct = proto->add_constraints()->mutable_linear(); + FillDomainInProto(ReadDomainFromProto(source).MultiplicationBy(factor), ct); + for (int i = 0; i < source.vars().size(); ++i) { + ct->add_vars(source.vars(i)); + ct->add_coeffs(source.coeffs(i) * factor); + } +} + +// Adds a column to the given problem which is a duplicate (with a random +// proportionality factor) of a random column. +// +// Note(user): This is not super efficient as we rescan the whole problem for +// this. +void AddRandomDuplicateColumn(absl::BitGen* random, CpModelProto* proto) { + const int new_var = proto->variables().size(); + const int source_var = absl::Uniform(*random, 0, new_var - 1); + + sat::IntegerVariableProto* var = proto->add_variables(); + var->add_domain(absl::Uniform(*random, -10, 10)); + var->add_domain(absl::Uniform(*random, var->domain(0), 10)); + + const int64_t factor = GetRandomNonZero(10, random); + for (int c = 0; c < proto->constraints().size(); ++c) { + LinearConstraintProto* linear = + proto->mutable_constraints(c)->mutable_linear(); + for (int i = 0; i < linear->vars().size(); ++i) { + if (linear->vars(i) == source_var) { + linear->add_vars(new_var); + linear->add_coeffs(linear->coeffs(i) * factor); + break; + } + } + } +} + +// Adds a random x == coeff * y + offset affine relation to the model. 
+void AddRandomAffineRelation(absl::BitGen* random, CpModelProto* proto) { + const int num_vars = proto->variables().size(); + const int a = absl::Uniform(*random, 0, num_vars - 1); + const int b = absl::Uniform(*random, 0, num_vars - 1); + if (a == b) return; + LinearConstraintProto* linear = proto->add_constraints()->mutable_linear(); + const int64_t offset = absl::Uniform(*random, -5, 5); + linear->add_domain(offset); + linear->add_domain(offset); + linear->add_vars(a); + linear->add_coeffs(1); + linear->add_vars(b); + linear->add_coeffs(GetRandomNonZero(5, random)); +} + +// Calls GenerateRandomBaseProblem() and extends the problem in various random +// ways. +CpModelProto GenerateRandomProblem(const std::string& env_name) { + absl::BitGen random; + CpModelProto result = GenerateRandomBaseProblem(&random); + for (int i = 0; i < absl::Uniform(random, 0, 10); ++i) { + switch (absl::Uniform(random, 0, 2)) { + case 0: + AddRandomDuplicateRow(&random, &result); + break; + case 1: + AddRandomDuplicateColumn(&random, &result); + break; + case 2: + AddRandomAffineRelation(&random, &result); + break; + } + } + return result; +} + +// Parameterized test to test different random lp. 
+class RandomPreprocessorTest : public ::testing::TestWithParam { + protected: + std::string GetSeedEnvName() { + return absl::StrFormat("TestCase%d", GetParam()); + } +}; + +TEST_P(RandomPreprocessorTest, SolveWithAndWithoutPresolve) { + const CpModelProto model_proto = GenerateRandomProblem(GetSeedEnvName()); + if (!absl::GetFlag(FLAGS_dump_dir).empty()) { + const std::string name = + file::JoinPath(absl::GetFlag(FLAGS_dump_dir), + absl::StrCat(GetSeedEnvName(), ".pb.txt")); + LOG(INFO) << "Dumping model to '" << name << "'."; + CHECK_OK(file::SetTextProto(name, model_proto, file::Defaults())); + } + + SatParameters params; + params.set_cp_model_presolve(true); + const CpSolverResponse response_with = + SolveWithParameters(model_proto, params); + params.set_cp_model_presolve(false); + const CpSolverResponse response_without = + SolveWithParameters(model_proto, params); + EXPECT_EQ(response_with.status(), response_without.status()); + EXPECT_NEAR(response_with.objective_value(), + // 1e-10 yields flakiness (<0.1%): see gpaste/5821350335741952. + response_without.objective_value(), 1e-9); +} + +// Note that because we just generate linear model, this doesn't exercise all +// the expansion code which is likely to lose the hint. Still it is a start. +TEST_P(RandomPreprocessorTest, TestHintSurvivePresolve) { + CpModelProto model_proto = GenerateRandomProblem(GetSeedEnvName()); + + // We only deal with feasible problem. Note that many are just INFEASIBLE, so + // maybe we should generate something smarter. + const CpSolverResponse first_solve = Solve(model_proto); + if (first_solve.status() != CpSolverStatus::OPTIMAL && + first_solve.status() != CpSolverStatus::FEASIBLE) { + return; + } + + const int num_vars = model_proto.variables().size(); + for (int i = 0; i < num_vars; ++i) { + model_proto.mutable_solution_hint()->add_vars(i); + model_proto.mutable_solution_hint()->add_values(first_solve.solution(i)); + } + + // We just check that the hint is correct. 
+ SatParameters params; + params.set_debug_crash_on_bad_hint(true); + params.set_stop_after_first_solution(true); + const CpSolverResponse with_hint = SolveWithParameters(model_proto, params); + + // Lets also test that the tightened domains contains the hint. + model_proto.clear_objective(); + model_proto.clear_solution_hint(); + SatParameters tighten_params; + tighten_params.set_keep_all_feasible_solutions_in_presolve(true); + tighten_params.set_fill_tightened_domains_in_response(true); + const CpSolverResponse with_tighten = + SolveWithParameters(model_proto, tighten_params); + EXPECT_EQ(with_tighten.tightened_variables().size(), num_vars); + for (int i = 0; i < num_vars; i++) { + EXPECT_TRUE(ReadDomainFromProto(with_tighten.tightened_variables(i)) + .Contains(first_solve.solution(i))); + } +} + +TEST_P(RandomPreprocessorTest, SolveDiophantineWithAndWithoutPresolve) { + absl::BitGen random; + CpModelProto model_proto; + model_proto.set_name("Random Diophantine"); + const int num_variables = absl::Uniform(random, 1, 6); + for (int i = 0; i < num_variables; ++i) { + IntegerVariableProto* var = model_proto.add_variables(); + int64_t min = absl::Uniform(random, -10, 10); + int64_t max = absl::Uniform(random, -10, 10); + if (max < min) std::swap(min, max); + var->add_domain(min); + var->add_domain(max); + } + const bool has_indicator = absl::Bernoulli(random, 0.5); + if (has_indicator) { + IntegerVariableProto* var = model_proto.add_variables(); + var->add_domain(0); + var->add_domain(1); + } + + auto* constraint = model_proto.add_constraints(); + if (has_indicator) constraint->add_enforcement_literal(num_variables); + auto* lin = constraint->mutable_linear(); + lin->add_domain(absl::Uniform(random, -10, 10)); + lin->add_domain(lin->domain(0)); + for (int v = 0; v < num_variables; ++v) { + lin->add_vars(v); + lin->add_coeffs(GetRandomNonZeroAndNonInvertible(10, &random)); + } + + model_proto.mutable_objective()->set_scaling_factor(1); + for (int v = 0; v < 
num_variables; ++v) { + if (absl::Bernoulli(random, 0.5)) { + model_proto.mutable_objective()->add_vars(v); + model_proto.mutable_objective()->add_coeffs( + absl::Uniform(random, -10, 10)); + } + } + if (has_indicator) { + // Indicator should be deactivated only if the equation is unfeasible. + model_proto.mutable_objective()->add_vars(num_variables); + model_proto.mutable_objective()->add_coeffs(-10000); + } + + if (!absl::GetFlag(FLAGS_dump_dir).empty()) { + const std::string name = + file::JoinPath(absl::GetFlag(FLAGS_dump_dir), + absl::StrCat(GetSeedEnvName(), ".pb.txt")); + LOG(INFO) << "Dumping model to '" << name << "'."; + CHECK_OK(file::SetTextProto(name, model_proto, file::Defaults())); + } + + SatParameters params; + params.set_cp_model_presolve(true); + const CpSolverResponse response_with = + SolveWithParameters(model_proto, params); + params.set_cp_model_presolve(false); + const CpSolverResponse response_without = + SolveWithParameters(model_proto, params); + EXPECT_EQ(response_with.status(), response_without.status()); + EXPECT_NEAR(response_with.objective_value(), + response_without.objective_value(), 1e-9); +} + +INSTANTIATE_TEST_SUITE_P(All, RandomPreprocessorTest, + ::testing::Range(0, DEBUG_MODE ? 500 : 5000)); + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_symmetries_test.cc b/ortools/sat/cp_model_symmetries_test.cc new file mode 100644 index 0000000000..0974788701 --- /dev/null +++ b/ortools/sat/cp_model_symmetries_test.cc @@ -0,0 +1,687 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_symmetries.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "gtest/gtest.h" +#include "ortools/algorithms/sparse_permutation.h" +#include "ortools/base/gmock.h" +#include "ortools/base/parse_test_proto.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/model.h" +#include "ortools/sat/presolve_context.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/logging.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::google::protobuf::contrib::parse_proto::ParseTestProto; + +const char kBaseModel[] = R"pb( + variables { + name: 'x' + domain: [ -5, 5 ] + } + variables { + name: 'y' + domain: [ 0, 10 ] + } + variables { + name: 'z' + domain: [ 0, 10 ] + } + constraints { + linear { + domain: [ 0, 10 ] + vars: [ 0, 1, 2 ] + coeffs: [ 1, 2, 2 ] + } + } + constraints { + linear { + domain: [ 2, 12 ] + vars: [ 0, 1, 2 ] + coeffs: [ 3, 2, 2 ] + } + } +)pb"; + +TEST(FindCpModelSymmetries, FindsSymmetry) { + const CpModelProto model = ParseTestProto(kBaseModel); + + std::vector> generators; + SolverLogger logger; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); +} + +TEST(FindCpModelSymmetries, NoSymmetryIfDifferentVariableBounds) { + CpModelProto model = ParseTestProto(kBaseModel); + model.mutable_variables(1)->set_domain(1, 20); + + std::vector> generators; + 
SolverLogger logger; + + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +TEST(FindCpModelSymmetries, NoSymmetryIfDifferentConstraintCoefficients) { + CpModelProto model = ParseTestProto(kBaseModel); + model.mutable_constraints(1)->mutable_linear()->set_coeffs(1, 1); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +TEST(FindCpModelSymmetries, NoSymmetryIfDifferentObjectiveCoefficients) { + CpModelProto model = ParseTestProto(kBaseModel); + model.mutable_objective()->add_vars(1); + model.mutable_objective()->add_coeffs(1); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +const char kConstraintSymmetryModel[] = R"pb( + variables { + name: 'x' + domain: [ -5, 5 ] + } + variables { + name: 'y' + domain: [ 0, 10 ] + } + variables { + name: 'z' + domain: [ 0, 10 ] + } + constraints { + linear { + domain: [ 0, 10 ] + vars: [ 0, 1, 2 ] + coeffs: [ 1, 2, 3 ] + } + } + constraints { + linear { + domain: [ 0, 10 ] + vars: [ 0, 1, 2 ] + coeffs: [ 1, 3, 2 ] + } + } +)pb"; + +TEST(FindCpModelSymmetries, FindsSymmetryIfSameConstraintBounds) { + CpModelProto model = ParseTestProto(kConstraintSymmetryModel); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); + + // Make sure that if the constraint bounds are different, the symmetry is + // broken. 
+ model.mutable_constraints(1)->mutable_linear()->set_domain(1, 20); + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +TEST(FindCpModelSymmetries, + NoSymmetryIfDifferentConstraintEnforcementLiterals) { + CpModelProto model = ParseTestProto(kConstraintSymmetryModel); + model.mutable_constraints(0)->add_enforcement_literal(0); + model.mutable_constraints(1)->add_enforcement_literal(1); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +TEST(FindCpModelSymmetries, FindsSymmetryIfSameConstraintEnforcementLiterals) { + CpModelProto model = ParseTestProto(kConstraintSymmetryModel); + model.mutable_constraints(0)->add_enforcement_literal(0); + model.mutable_constraints(1)->add_enforcement_literal(0); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); +} + +TEST(FindCpModelSymmetries, + FindsSymmetryIfSameNegativeConstraintEnforcementLiterals) { + CpModelProto model = ParseTestProto(kConstraintSymmetryModel); + model.mutable_constraints(0)->add_enforcement_literal(-1); + model.mutable_constraints(1)->add_enforcement_literal(-1); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); +} + +TEST(FindCpModelSymmetries, LinMaxConstraint) { + CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + lin_max { + target { vars: 0 coeffs: 1 } + exprs { vars: 1, coeffs: 2 
} + exprs { vars: 2, coeffs: 2 } + exprs { vars: 3, coeffs: 3 } + } + } + )pb"); + SolverLogger logger; + + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); +} + +TEST(FindCpModelSymmetries, UnsupportedConstraintTypeReturnsNoGenerators) { + CpModelProto model = ParseTestProto(R"pb( + variables { + name: 'x' + domain: [ -5, 5 ] + } + variables { + name: 'y' + domain: [ 0, 10 ] + } + variables { + name: 'z' + domain: [ 0, 10 ] + } + constraints { routes {} } + )pb"); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +// We ignore variables that do not appear in any constraint. +TEST(FindCpModelSymmetries, FindsSymmetryIfNoConstraints) { + CpModelProto model = ParseTestProto(R"pb( + variables { + name: 'x' + domain: [ 0, 10 ] + } + variables { + name: 'y' + domain: [ -5, 5 ] + } + variables { + name: 'z' + domain: [ 0, 10 ] + } + )pb"); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + ASSERT_EQ(generators.size(), 0); +} + +TEST(FindCpModelSymmetries, NoSymmetryIfDuplicateConstraints) { + CpModelProto model = ParseTestProto(R"pb( + variables { + name: 'x' + domain: [ -5, 5 ] + } + variables { + name: 'y' + domain: [ 0, 10 ] + } + variables { + name: 'z' + domain: [ -5, 10 ] + } + constraints { + linear { + domain: [ 0, 10 ] + vars: [ 0, 1, 2 ] + coeffs: [ 1, 2, 3 ] + } + } + constraints { + linear { + domain: [ 0, 10 ] + vars: [ 0, 1, 2 ] + coeffs: [ 1, 2, 3 ] + } + } + )pb"); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 0); +} + +// a => not(b) +// 
not(a) => c >= 4 +// not(b) => d >= 4 +TEST(FindCpModelSymmetries, ImplicationTestThatUsedToFail) { + CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + enforcement_literal: 0 + bool_and { literals: [ -2 ] } + } + constraints { + enforcement_literal: -1 + linear { + vars: [ 2 ] + coeffs: [ 1 ] + domain: [ 4, 10 ] + } + } + constraints { + enforcement_literal: -2 + linear { + vars: [ 3 ] + coeffs: [ 1 ] + domain: [ 4, 10 ] + } + } + )pb"); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 1) (2 3)"); +} + +TEST(DetectAndAddSymmetryToProto, BasicTest) { + // A model with one (0, 1) (2, 3) symmetry. + CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + enforcement_literal: 0 + bool_and { literals: [ -2 ] } + } + constraints { + enforcement_literal: -1 + linear { + vars: [ 2 ] + coeffs: [ 1 ] + domain: [ 4, 10 ] + } + } + constraints { + enforcement_literal: -2 + linear { + vars: [ 3 ] + coeffs: [ 1 ] + domain: [ 4, 10 ] + } + } + )pb"); + + SolverLogger logger; + SatParameters params; + params.set_log_search_progress(true); + DetectAndAddSymmetryToProto(params, &model, &logger); + + // TODO(user): canonicalize the order in each cycle? 
+ const SymmetryProto expected = ParseTestProto(R"pb( + permutations { + support: [ 1, 0, 3, 2 ] + cycle_sizes: [ 2, 2 ] + } + )pb"); + + EXPECT_THAT(model.symmetry(), testing::EqualsProto(expected)); +} + +const char kBooleanModel[] = R"pb( + variables { + name: 'x' + domain: [ 0, 1 ] + } + variables { + name: 'y' + domain: [ 0, 1 ] + } + variables { + name: 'z' + domain: [ 0, 1 ] + } +)pb"; + +TEST(FindCpModelSymmetries, FindsSymmetryInBoolOr) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { bool_or { literals: [ 0, 1 ] } } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 1)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInNegatedBoolOr) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { bool_or { literals: [ -1, -3 ] } } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInBoolOrWithEnforcementLiteral) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { + enforcement_literal: 1 + bool_or { literals: [ 0, 2 ] } + } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInBoolXor) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { bool_xor { literals: [ 0, 2 ] } } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, 
model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInNegatedBoolXor) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { bool_or { literals: [ -2, -3 ] } } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInBoolXorWithEnforcementLiteral) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { + enforcement_literal: 0 + bool_or { literals: [ 1, 2 ] } + } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(1 2)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInBoolAnd) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { + enforcement_literal: 1 + bool_and { literals: [ 0, 2 ] } + } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2)"); +} + +TEST(FindCpModelSymmetries, FindsSymmetryInNegatedBoolAnd) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { + enforcement_literal: 2 + bool_and { literals: [ -1, -2 ] } + } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 1)"); +} + 
+TEST(FindCpModelSymmetries, FindsSymmetryInBoolAndWithEnforcementLiteral) { + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { + enforcement_literal: 2 + bool_and { literals: [ 0, 1 ] } + } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 1)"); +} + +TEST(FindCpModelSymmetries, + FindsSymmetryInBoolOrsAndBoolAndWithEnforcementLiteral) { + // The two BoolOrs and the BoolAnd are equivalent, so this should find a + // symmetry between literals 0 and 1. + CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( + constraints { bool_or { literals: [ 2, 1 ] } } + constraints { bool_or { literals: [ 0, 2 ] } } + constraints { + enforcement_literal: -3 + bool_and { literals: [ 0, 1 ] } + } + )pb")); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + std::numeric_limits::infinity(), &logger); + + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 1)"); +} + +TEST(FindCpModelSymmetries, BasicSchedulingCase) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } # start 1 + variables { domain: [ 0, 10 ] } # start 2 + variables { domain: [ 0, 10 ] } # start 3 + constraints { + interval { + start { vars: 0 coeffs: 1 offset: 0 } + size { offset: 4 } + end { vars: 0 coeffs: 1 offset: 4 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 offset: 0 } + size { offset: 5 } + end { vars: 1 coeffs: 1 offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 offset: 0 } + size { offset: 4 } + end { vars: 2 coeffs: 1 offset: 4 } + } + } + constraints { no_overlap { intervals: [ 0, 1, 2 ] } } + )pb"); + + SolverLogger logger; + std::vector> generators; + FindCpModelSymmetries({}, model, &generators, + 
std::numeric_limits::infinity(), &logger); + + // The two intervals with the same size can be swapped. + ASSERT_EQ(generators.size(), 1); + EXPECT_EQ(generators[0]->DebugString(), "(0 2)"); +} + +// Assigning n items to b identical bins is an example of orbitope since the +// bins can be freely permuted. +TEST(FindCpModelSymmetries, BinPacking) { + constexpr int num_items = 10; + constexpr int num_bins = 7; + CpModelProto proto; + + // One Boolean per possible assignment. + int item_to_bin[num_items][num_bins]; + for (int i = 0; i < num_items; ++i) { + for (int b = 0; b < num_bins; ++b) { + item_to_bin[i][b] = proto.variables().size(); + auto* var = proto.add_variables(); + var->add_domain(0); + var->add_domain(1); + } + } + + // At most one per row. + for (int i = 0; i < num_items; ++i) { + auto* amo = proto.add_constraints()->mutable_at_most_one(); + for (int b = 0; b < num_bins; ++b) { + amo->add_literals(item_to_bin[i][b]); + } + } + + // Simple capacity constraint. + for (int b = 0; b < num_bins; ++b) { + auto* linear = proto.add_constraints()->mutable_linear(); + for (int i = 0; i < num_items; ++i) { + linear->add_vars(item_to_bin[i][b]); + linear->add_coeffs(i + 1); + } + linear->add_domain(0); + linear->add_domain(10); // <= 10. + } + + Model model; + model.GetOrCreate()->EnableLogging(true); + model.GetOrCreate()->SetLogToStdOut(true); + PresolveContext context(&model, &proto, nullptr); + context.InitializeNewDomains(); + context.UpdateNewConstraintsVariableUsage(); + context.ReadObjectiveFromProto(); + EXPECT_TRUE(DetectAndExploitSymmetriesInPresolve(&context)); + context.LogInfo(); + + // We have a 10 x 7 orbitope. + // Note that here we do not do propagation, just fixing to zero according + // to the orbitope and the at most ones. We should fix 6 on the first row, + // and one less per row after that. 
+ for (int i = 0; i < num_items; ++i) { + int num_fixed = 0; + for (int b = 0; b < num_bins; ++b) { + if (context.IsFixed(item_to_bin[i][b])) ++num_fixed; + } + CHECK_EQ(num_fixed, std::max(0, 6 - i)) << i; + } +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_table.cc b/ortools/sat/cp_model_table.cc new file mode 100644 index 0000000000..ee574610e2 --- /dev/null +++ b/ortools/sat/cp_model_table.cc @@ -0,0 +1,323 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_table.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "ortools/base/stl_util.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/presolve_context.h" + +namespace operations_research { +namespace sat { + +void CanonicalizeTable(PresolveContext* context, ConstraintProto* ct) { + if (context->ModelIsUnsat()) return; + + const int num_exprs = ct->table().exprs_size(); + const int num_tuples = ct->table().values_size() / num_exprs; + + // Detect expressions sharing the same variable as a previous expression. + absl::flat_hash_map var_to_position; + + // The mapping between the position in the original list of expressions, and + // the position in the reduced list of expressions. 
+ std::vector> position_mapping(num_exprs, std::nullopt); + int num_shared_vars = 0; + for (int i = 0; i < num_exprs; ++i) { + const LinearExpressionProto& expr = ct->table().exprs(i); + if (context->IsFixed(expr)) continue; + + const int var = expr.vars(0); + const auto [it, inserted] = + var_to_position.insert({var, var_to_position.size()}); + if (!inserted) { + ++num_shared_vars; + position_mapping[i] = it->second; + } + } + + const int num_kept_exprs = num_exprs - num_shared_vars; + + std::vector> new_tuples; + new_tuples.reserve(num_tuples); + + std::vector new_scaled_values; + new_scaled_values.reserve(num_kept_exprs); + + for (int t = 0; t < num_tuples; ++t) { + bool tuple_is_valid = true; + new_scaled_values.clear(); + + for (int e = 0; e < num_exprs; ++e) { + const int64_t value = ct->table().values(t * num_exprs + e); + const LinearExpressionProto& expr = ct->table().exprs(e); + if (context->IsFixed(expr)) { + if (value != context->FixedValue(expr)) { + tuple_is_valid = false; + break; + } + new_scaled_values.push_back(value); + } else if (position_mapping[e].has_value()) { + const int var_first_position = position_mapping[e].value(); + const int64_t var_value = new_scaled_values[var_first_position]; + const int64_t forced_value = AffineExpressionValueAt(expr, var_value); + if (value != forced_value) { + tuple_is_valid = false; + break; + } + } else { + if (!context->DomainContains(expr, value)) { + tuple_is_valid = false; + break; + } + new_scaled_values.push_back(GetInnerVarValue(expr, value)); + } + } + + if (tuple_is_valid) { + DCHECK_EQ(new_scaled_values.size(), num_kept_exprs); + new_tuples.push_back(new_scaled_values); + } + } + + // Remove all scaling on expressions as we have stored the inner values. 
+ for (int e = 0; e < num_exprs; ++e) { + if (position_mapping[e].has_value()) continue; + if (context->IsFixed(ct->table().exprs(e))) continue; + DCHECK_EQ(ct->table().exprs(e).coeffs_size(), 1); + ct->mutable_table()->mutable_exprs(e)->set_offset(0); + ct->mutable_table()->mutable_exprs(e)->set_coeffs(0, 1); + } + + if (num_kept_exprs < num_exprs) { + int index = 0; + for (int e = 0; e < num_exprs; ++e) { + if (position_mapping[e].has_value()) continue; + ct->mutable_table()->mutable_exprs()->SwapElements(index++, e); + } + CHECK_EQ(index, num_kept_exprs); + ct->mutable_table()->mutable_exprs()->DeleteSubrange(index, + num_exprs - index); + context->UpdateRuleStats("table: remove expressions"); + } + + gtl::STLSortAndRemoveDuplicates(&new_tuples); + if (new_tuples.size() < num_tuples) { + context->UpdateRuleStats("table: remove tuples"); + } + + // Write sorted tuples. + ct->mutable_table()->clear_values(); + for (const std::vector& tuple : new_tuples) { + ct->mutable_table()->mutable_values()->Add(tuple.begin(), tuple.end()); + } +} + +void RemoveFixedColumnsFromTable(PresolveContext* context, + ConstraintProto* ct) { + if (context->ModelIsUnsat()) return; + const int num_exprs = ct->table().exprs_size(); + const int num_tuples = ct->table().values_size() / num_exprs; + std::vector is_fixed(num_exprs, false); + int num_fixed_exprs = 0; + for (int e = 0; e < num_exprs; ++e) { + is_fixed[e] = context->IsFixed(ct->table().exprs(e)); + num_fixed_exprs += is_fixed[e]; + } + if (num_fixed_exprs == 0) return; + + int num_kept_exprs = num_exprs - num_fixed_exprs; + + int index = 0; + for (int e = 0; e < num_exprs; ++e) { + if (is_fixed[e]) continue; + ct->mutable_table()->mutable_exprs()->SwapElements(index++, e); + } + CHECK_EQ(index, num_kept_exprs); + ct->mutable_table()->mutable_exprs()->DeleteSubrange(index, + num_exprs - index); + index = 0; + for (int t = 0; t < num_tuples; ++t) { + for (int e = 0; e < num_exprs; ++e) { + if (is_fixed[e]) continue; + 
ct->mutable_table()->set_values(index++, + ct->table().values(t * num_exprs + e)); + } + } + CHECK_EQ(index, num_tuples * num_kept_exprs); + ct->mutable_table()->mutable_values()->Truncate(index); + + context->UpdateRuleStats("table: remove fixed columns"); +} + +void CompressTuples(absl::Span domain_sizes, + std::vector>* tuples) { + if (tuples->empty()) return; + + // Remove duplicates if any. + gtl::STLSortAndRemoveDuplicates(tuples); + + const int num_vars = (*tuples)[0].size(); + + std::vector to_remove; + std::vector tuple_minus_var_i(num_vars - 1); + for (int i = 0; i < num_vars; ++i) { + const int domain_size = domain_sizes[i]; + if (domain_size == 1) continue; + absl::flat_hash_map, std::vector> + masked_tuples_to_indices; + for (int t = 0; t < tuples->size(); ++t) { + int out = 0; + for (int j = 0; j < num_vars; ++j) { + if (i == j) continue; + tuple_minus_var_i[out++] = (*tuples)[t][j]; + } + masked_tuples_to_indices[tuple_minus_var_i].push_back(t); + } + to_remove.clear(); + for (const auto& it : masked_tuples_to_indices) { + if (it.second.size() != domain_size) continue; + (*tuples)[it.second.front()][i] = kTableAnyValue; + to_remove.insert(to_remove.end(), it.second.begin() + 1, it.second.end()); + } + std::sort(to_remove.begin(), to_remove.end(), std::greater()); + for (const int t : to_remove) { + (*tuples)[t] = tuples->back(); + tuples->pop_back(); + } + } +} + +namespace { + +// We will call FullyCompressTuplesRecursive() for a set of prefixes of the +// original tuples, each having the same suffix (in reversed_suffix). +// +// For such set, we will compress it on the last variable of the prefixes. We +// will then for each unique compressed set of value of that variable, call +// a new FullyCompressTuplesRecursive() on the corresponding subset. 
+void FullyCompressTuplesRecursive( + absl::Span domain_sizes, + absl::Span> tuples, + std::vector>* reversed_suffix, + std::vector>>* output) { + struct TempData { + absl::InlinedVector values; + int index; + + bool operator<(const TempData& other) const { + return values < other.values; + } + }; + std::vector temp_data; + + CHECK(!tuples.empty()); + CHECK(!tuples[0].empty()); + const int64_t domain_size = domain_sizes[tuples[0].size() - 1]; + + // Sort tuples and regroup common prefix in temp_data. + std::sort(tuples.begin(), tuples.end()); + for (int i = 0; i < tuples.size();) { + const int start = i; + temp_data.push_back({{tuples[start].back()}, start}); + tuples[start].pop_back(); + for (++i; i < tuples.size(); ++i) { + const int64_t v = tuples[i].back(); + tuples[i].pop_back(); + if (tuples[i] == tuples[start]) { + temp_data.back().values.push_back(v); + } else { + tuples[i].push_back(v); + break; + } + } + + // If one of the value is the special value kTableAnyValue, we convert + // it to the "empty means any" format. + for (const int64_t v : temp_data.back().values) { + if (v == kTableAnyValue) { + temp_data.back().values.clear(); + break; + } + } + gtl::STLSortAndRemoveDuplicates(&temp_data.back().values); + + // If values cover the whole domain, we clear the vector. This allows to + // use less space and avoid creating unneeded clauses. + if (temp_data.back().values.size() == domain_size) { + temp_data.back().values.clear(); + } + } + + if (temp_data.size() == 1) { + output->push_back({}); + for (const int64_t v : tuples[temp_data[0].index]) { + if (v == kTableAnyValue) { + output->back().push_back({}); + } else { + output->back().push_back({v}); + } + } + output->back().push_back(temp_data[0].values); + for (int i = reversed_suffix->size(); --i >= 0;) { + output->back().push_back((*reversed_suffix)[i]); + } + return; + } + + // Sort temp_data and make recursive call for all tuples that share the + // same suffix. 
+ std::sort(temp_data.begin(), temp_data.end()); + std::vector> temp_tuples; + for (int i = 0; i < temp_data.size();) { + reversed_suffix->push_back(temp_data[i].values); + const int start = i; + temp_tuples.clear(); + for (; i < temp_data.size(); i++) { + if (temp_data[start].values != temp_data[i].values) break; + temp_tuples.push_back(tuples[temp_data[i].index]); + } + FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(temp_tuples), + reversed_suffix, output); + reversed_suffix->pop_back(); + } +} + +} // namespace + +// TODO(user): We can probably reuse the tuples memory always and never create +// new one. We should also be able to code an iterative version of this. Note +// however that the recursion level is bounded by the number of columns which +// should be small. +std::vector>> FullyCompressTuples( + absl::Span domain_sizes, + std::vector>* tuples) { + std::vector> reversed_suffix; + std::vector>> output; + FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(*tuples), + &reversed_suffix, &output); + return output; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_table.h b/ortools/sat/cp_model_table.h new file mode 100644 index 0000000000..30cc46c431 --- /dev/null +++ b/ortools/sat/cp_model_table.h @@ -0,0 +1,73 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OR_TOOLS_SAT_CP_MODEL_TABLE_H_ +#define OR_TOOLS_SAT_CP_MODEL_TABLE_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/presolve_context.h" + +namespace operations_research { +namespace sat { + +// Canonicalizes the table constraint by removing all unreachable tuples, and +// all columns which have the same variable of a previous column. +// +// This also sort all the tuples. +void CanonicalizeTable(PresolveContext* context, ConstraintProto* ct); + +// Removed all fixed columns from the table. +void RemoveFixedColumnsFromTable(PresolveContext* context, ConstraintProto* ct); + +// This method tries to compress a list of tuples by merging complementary +// tuples, that is a set of tuples that only differ on one variable, and that +// cover the domain of the variable. In that case, it will keep only one tuple, +// and replace the value for variable by any_value, the equivalent of '*' in +// regexps. +// +// This method is exposed for testing purposes. +constexpr int64_t kTableAnyValue = std::numeric_limits::min(); +void CompressTuples(absl::Span domain_sizes, + std::vector>* tuples); + +// Similar to CompressTuples() but produces a final table where each cell is +// a set of value. This should result in a table that can still be encoded +// efficiently in SAT but with less tuples and thus less extra Booleans. Note +// that if a set of value is empty, it is interpreted at "any" so we can gain +// some space. +// +// The passed tuples vector is used as temporary memory and is detroyed. +// We interpret kTableAnyValue as an "any" tuple. +// +// TODO(user): To reduce memory, we could return some absl::Span in the last +// layer instead of vector. +// +// TODO(user): The final compression is depend on the order of the variables. 
+// For instance the table (1,1)(1,2)(1,3),(1,4),(2,3) can either be compressed +// as (1,*)(2,3) or (1,{1,2,4})({1,3},3). More experiment are needed to devise +// a better heuristic. It might for example be good to call CompressTuples() +// first. +std::vector>> FullyCompressTuples( + absl::Span domain_sizes, + std::vector>* tuples); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_TABLE_H_ diff --git a/ortools/sat/cp_model_table_test.cc b/ortools/sat/cp_model_table_test.cc new file mode 100644 index 0000000000..72c8ba0d37 --- /dev/null +++ b/ortools/sat/cp_model_table_test.cc @@ -0,0 +1,142 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_table.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(CompressTuplesTest, OneAny) { + const std::vector domain_sizes = {2, 2, 2, 4}; + std::vector> tuples = { + {0, 0, 0, 0}, + {1, 1, 0, 2}, + {0, 0, 1, 3}, + {0, 1, 1, 3}, + }; + CompressTuples(domain_sizes, &tuples); + const std::vector>& expected_tuples = { + {0, 0, 0, 0}, + {0, kTableAnyValue, 1, 3}, // Result is sorted. 
+ {1, 1, 0, 2}, + }; + EXPECT_EQ(tuples, expected_tuples); +} + +TEST(CompressTuplesTest, NotPerfect) { + const std::vector domain_sizes = {3, 3}; + std::vector> tuples = { + {0, 0}, {0, 1}, {0, 2}, {1, 2}, {2, 2}, + }; + CompressTuples(domain_sizes, &tuples); + + // Here we could return instead: + // {0, kint64min} + // {kint64min, 2} + const std::vector>& expected_tuples = { + {0, 0}, + {0, 1}, + {kTableAnyValue, 2}, + }; + EXPECT_EQ(tuples, expected_tuples); +} + +TEST(FullyCompressTuplesTest, BasicTest) { + const std::vector domain_sizes = {4, 4}; + std::vector> tuples = { + {0, 1}, {0, 2}, {0, 3}, {1, 1}, {1, 2}, + }; + const auto result = FullyCompressTuples(domain_sizes, &tuples); + const std::vector>>& expected = { + {{1}, {1, 2}}, + {{0}, {1, 2, 3}}, + }; + EXPECT_EQ(result, expected); +} + +TEST(CompressTuplesTest, BasicTest2) { + const std::vector domain_sizes = {4, 4, 4, 4}; + std::vector> tuples = { + {0, 0, 0, 0}, + {1, 1, 0, 2}, + {0, 0, 1, 3}, + {0, 1, 1, 3}, + }; + const auto result = FullyCompressTuples(domain_sizes, &tuples); + const std::vector>>& expected = { + {{0}, {0}, {0}, {0}}, {{1}, {1}, {0}, {2}}, {{0}, {0, 1}, {1}, {3}}}; + EXPECT_EQ(result, expected); +} + +TEST(CompressTuplesTest, BasicTest3) { + const std::vector domain_sizes = {4, 4, 4, 4}; + std::vector> tuples = { + {0, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 0, 0}, {1, 1, 0, 0}, + {0, 0, 2, 0}, {0, 1, 2, 0}, {1, 0, 2, 0}, {1, 1, 2, 0}, + }; + const auto result = FullyCompressTuples(domain_sizes, &tuples); + const std::vector>>& expected = { + {{0, 1}, {0, 1}, {0, 2}, {0}}}; + EXPECT_EQ(result, expected); +} + +TEST(FullyCompressTuplesTest, BasicTestWithAnyValue) { + const std::vector domain_sizes = {4, 3}; + std::vector> tuples = { + {0, 1}, {0, 2}, {0, 3}, {1, 1}, {1, 2}, + }; + const auto result = FullyCompressTuples(domain_sizes, &tuples); + const std::vector>>& expected = { + {{0}, {}}, + {{1}, {1, 2}}, + }; + EXPECT_EQ(result, expected); +} + +TEST(FullyCompressTuplesTest, 
ConvertAnyValueRepresentation) { + const std::vector domain_sizes = {4, 3}; + std::vector> tuples = {{0, kTableAnyValue}, + {kTableAnyValue, 2}}; + const auto result = FullyCompressTuples(domain_sizes, &tuples); + const std::vector>>& expected = { + {{0}, {}}, + {{}, {2}}, + }; + EXPECT_EQ(result, expected); +} + +TEST(FullyCompressTuplesTest, ConvertAnyValueRepresentation2) { + const std::vector domain_sizes = {4, 3, 2, 3}; + std::vector> tuples = { + {0, kTableAnyValue, 3, kTableAnyValue}}; + const auto result = FullyCompressTuples(domain_sizes, &tuples); + const std::vector>>& expected = { + {{0}, {}, {3}, {}}, + }; + EXPECT_EQ(result, expected); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_test_utils.cc b/ortools/sat/cp_model_test_utils.cc new file mode 100644 index 0000000000..35de434e78 --- /dev/null +++ b/ortools/sat/cp_model_test_utils.cc @@ -0,0 +1,111 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/cp_model_test_utils.h" + +#include + +#include +#include + +#include "absl/random/random.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" + +namespace operations_research { +namespace sat { + +CpModelProto Random3SatProblem(int num_variables, + double proportion_of_constraints) { + CpModelProto result; + absl::BitGen random; + result.set_name("Random 3-SAT"); + for (int i = 0; i < num_variables; ++i) { + sat::IntegerVariableProto* var = result.add_variables(); + var->add_domain(0); + var->add_domain(1); + } + const int num_constraints = proportion_of_constraints * num_variables; + for (int i = 0; i < num_constraints; ++i) { + auto* ct = result.add_constraints()->mutable_bool_or(); + std::vector clause; + while (ct->literals_size() != 3) { + const int literal = + absl::Uniform(random, NegatedRef(num_variables - 1), num_variables); + bool is_already_present = false; + for (const int lit : ct->literals()) { + if (lit != literal) continue; + is_already_present = true; + break; + } + if (!is_already_present) ct->add_literals(literal); + } + } + return result; +} + +CpModelProto RandomLinearProblem(int num_variables, int num_constraints) { + CpModelProto result; + absl::BitGen random; + result.set_name("Random 0-1 linear problem"); + for (int i = 0; i < num_variables; ++i) { + sat::IntegerVariableProto* var = result.add_variables(); + var->add_domain(0); + var->add_domain(1); + } + for (int i = 0; i < num_constraints; ++i) { + // Sum >= num_variables / 10. + auto* ct = result.add_constraints()->mutable_linear(); + const int min_value = num_variables / 10; + ct->add_domain(min_value); + ct->add_domain(std::numeric_limits::max()); + for (int v = 0; v < num_variables; ++v) { + if (absl::Bernoulli(random, 0.5) || + // To ensure that the constraint is feasible, we enforce that it has + // at least the 'minimum' number of terms. This clause should only + // rarely be used, when num_variables is high. 
+ num_variables - v <= min_value - ct->vars_size()) { + ct->add_vars(v); + ct->add_coeffs(1); + } + } + } + + // Objective: minimize variables at one. + { + const int objective_var_index = result.variables_size(); + { + sat::IntegerVariableProto* var = result.add_variables(); + var->add_domain(0); + var->add_domain(num_variables); + } + result.mutable_objective()->add_vars(objective_var_index); + result.mutable_objective()->add_coeffs(1); + + // Sum of all other variables == 0 + auto* ct = result.add_constraints()->mutable_linear(); + ct->add_domain(0); + ct->add_domain(0); + for (int v = 0; v < num_variables; ++v) { + ct->add_vars(v); + ct->add_coeffs(1); + } + ct->add_vars(objective_var_index); + ct->add_coeffs(-1); + } + + return result; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_test_utils.h b/ortools/sat/cp_model_test_utils.h new file mode 100644 index 0000000000..07837a4e38 --- /dev/null +++ b/ortools/sat/cp_model_test_utils.h @@ -0,0 +1,36 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_CP_MODEL_TEST_UTILS_H_ +#define OR_TOOLS_SAT_CP_MODEL_TEST_UTILS_H_ + +#include "ortools/sat/cp_model.pb.h" + +namespace operations_research { +namespace sat { + +// Generates a random 3-SAT problem with a number of constraints given by: +// num_variables * proportions_of_constraints. 
With the default proportion +// value, we are around the transition SAT/UNSAT. +CpModelProto Random3SatProblem(int num_variables, + double proportion_of_constraints = 4.26); + +// Generates a random 0-1 "covering" optimization linear problem: +// - Each constraint has density ~0.5 and ask for a sum >= num_variables / 10. +// - The objective is to minimize the number of variables at 1. +CpModelProto RandomLinearProblem(int num_variables, int num_constraints); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_TEST_UTILS_H_ diff --git a/ortools/sat/csharp/CpModel.cs b/ortools/sat/csharp/CpModel.cs index fb52881249..ce7041c583 100644 --- a/ortools/sat/csharp/CpModel.cs +++ b/ortools/sat/csharp/CpModel.cs @@ -329,23 +329,23 @@ public MultipleCircuitConstraint AddMultipleCircuit() return ct; } - /** - * - * Adds AllowedAssignments(expressions). - * - * - * An AllowedAssignments constraint is a constraint on an array of affine - * expressions (a * var + b) that forces, when all expressions are fixed to a single - * value, that the corresponding list of values is equal to one of the tuples of the - * tupleList. - * - * - * a list of affine expressions (a * var + b) - * an instance of the TableConstraint class without any tuples. Tuples can be added - * directly to the table constraint - * - */ - public TableConstraint AddAllowedAssignments(IEnumerable exprs) + /** + * + * Adds AllowedAssignments(expressions). + * + * + * An AllowedAssignments constraint is a constraint on an array of affine + * expressions (a * var + b) that forces, when all expressions are fixed to a single + * value, that the corresponding list of values is equal to one of the tuples of the + * tupleList. + * + * + * a list of affine expressions (a * var + b) + * an instance of the TableConstraint class without any tuples. 
Tuples can be added + * directly to the table constraint + * + */ + public TableConstraint AddAllowedAssignments(IEnumerable exprs) { TableConstraintProto table = new TableConstraintProto(); table.Vars.TrySetCapacity(exprs); diff --git a/ortools/sat/csharp/SatSolverTests.cs b/ortools/sat/csharp/SatSolverTests.cs index 8b3ede166a..6d18a08d5a 100644 --- a/ortools/sat/csharp/SatSolverTests.cs +++ b/ortools/sat/csharp/SatSolverTests.cs @@ -398,7 +398,7 @@ public void ValueElement() CpModel model = new CpModel(); IntVar v1 = model.NewIntVar(1, 10, "v1"); IntVar v2 = model.NewIntVar(1, 10, "v2"); - model.AddElement(v1 + 2, new int[] {1, 3, 5}, 5 - v2); + model.AddElement(v1 + 2, new int[] { 1, 3, 5 }, 5 - v2); Assert.Equal(3, model.Model.Constraints[0].Element.Exprs.Count); } @@ -410,7 +410,7 @@ public void ExprElement() IntVar x = model.NewIntVar(0, 5, "x"); IntVar y = model.NewIntVar(0, 5, "y"); IntVar z = model.NewIntVar(0, 5, "z"); - model.AddElement(v1, new LinearExpr[] {x + 2, -y, LinearExpr.Constant(5), 2 * z}, 5 - v2); + model.AddElement(v1, new LinearExpr[] { x + 2, -y, LinearExpr.Constant(5), 2 * z }, 5 - v2); Assert.Equal(4, model.Model.Constraints[0].Element.Exprs.Count); } diff --git a/ortools/sat/diophantine_test.cc b/ortools/sat/diophantine_test.cc new file mode 100644 index 0000000000..b07fb35c7d --- /dev/null +++ b/ortools/sat/diophantine_test.cc @@ -0,0 +1,225 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/diophantine.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/log_severity.h" +#include "absl/numeric/int128.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "gtest/gtest.h" +#include "ortools/sat/util.h" + +namespace operations_research::sat { + +namespace { + +TEST(ReduceModuloBasis, LookZeroRows) { + const std::vector> basis = {{0, 1}, {0, 0, 1}}; + std::vector v = {1, 2, 3}; + ReduceModuloBasis(basis, 0, v); + EXPECT_EQ(v, std::vector({1, 2, 3})); +} + +TEST(ReduceModuloBasis, LookOneRows) { + const std::vector> basis = {{0, 1}, {0, 0, 1}}; + std::vector v = {1, 2, 3}; + ReduceModuloBasis(basis, 1, v); + EXPECT_EQ(v, std::vector({1, 0, 3})); +} + +TEST(ReduceModuloBasis, LookTwoRows) { + const std::vector> basis = {{0, 1}, {0, 0, 1}}; + std::vector v = {1, 2, 3}; + ReduceModuloBasis(basis, 2, v); + EXPECT_EQ(v, std::vector({1, 0, 0})); +} + +template +T UniformNonZero(URBG&& random, T lo, T hi) { + T result = absl::Uniform(random, lo, hi - 1); + return result < 0 ? 
result : result + 1; +} + +class RandomTest : public ::testing::TestWithParam { + protected: + std::string GetSeedEnvName() { + return absl::StrFormat("TestCase%d", GetParam()); + } +}; + +TEST_P(RandomTest, ReduceModuloBasis) { + absl::BitGen random; + + const int num_rows = absl::Uniform(random, 1, 20); + + std::vector> basis; + for (int i = 0; i < num_rows; ++i) { + basis.emplace_back(i + 2); + for (int j = 0; j < i + 1; ++j) { + basis.back()[j] = absl::Uniform(random, -100, 100); + } + basis.back()[i + 1] = UniformNonZero(random, -100, 100); + } + std::vector v_reduced(num_rows + 1); + v_reduced[0] = absl::Uniform(random, -100, 100); + for (int i = 1; i <= num_rows; ++i) { + int pivot = std::abs(static_cast(basis[i - 1][i])); + v_reduced[i] = absl::Uniform(random, -FloorOfRatio(pivot, 2), + CeilOfRatio(pivot, 2)); + } + std::vector v = v_reduced; + for (int i = 0; i < num_rows; ++i) { + int m = absl::Uniform(random, -100, 100); + for (int j = 0; j < basis[i].size(); ++j) { + v[j] += m * basis[i][j]; + } + } + ReduceModuloBasis(basis, num_rows, v); + EXPECT_EQ(v, v_reduced); +} + +TEST_P(RandomTest, GreedyFastDecreasingGcd) { + absl::BitGen random; + + const int num_elements = absl::Uniform(random, 1, 50); + std::vector coeffs(num_elements); + for (int i = 0; i < num_elements; ++i) { + coeffs[i] = UniformNonZero(random, 1 + std::numeric_limits::min(), + std::numeric_limits::max()); + } + const std::vector order = GreedyFastDecreasingGcd(coeffs); + if (order.empty()) { + int64_t gcd = std::abs(coeffs[0]); + int64_t min_elem = std::abs(coeffs[0]); + for (int i = 1; i < num_elements; ++i) { + min_elem = std::min(min_elem, std::abs(coeffs[i])); + gcd = std::gcd(gcd, std::abs(coeffs[i])); + } + EXPECT_EQ(gcd, min_elem); + return; + } + + // order should be a permutation. 
+ EXPECT_EQ(order.size(), num_elements); + std::vector seen(num_elements, false); + for (const int i : order) { + EXPECT_FALSE(seen[i]); + seen[i] = true; + } + + // GCD should be decreasing then static. + int64_t gcd = std::abs(coeffs[order[0]]); + bool constant = false; + int non_constant_terms = 1; + + for (int i = 1; i < num_elements; ++i) { + const int64_t new_gcd = std::gcd(gcd, std::abs(coeffs[order[i]])); + if (new_gcd != gcd) { + EXPECT_FALSE(constant); + gcd = new_gcd; + ++non_constant_terms; + } else { + constant = true; + } + } + EXPECT_GE(non_constant_terms, 2); // order should be empty. + EXPECT_LE(non_constant_terms, 15); +} + +TEST_P(RandomTest, SolveDiophantine) { + absl::BitGen random; + + // Creates a constraint and a solution. + const int num_elements = absl::Uniform(random, 1, 50); + std::vector coeffs(num_elements); + for (int i = 0; i < num_elements; ++i) { + coeffs[i] = UniformNonZero(random, -1000000, 1000000); + } + + std::vector particular_solution(num_elements); + std::vector lbs(num_elements); + std::vector ubs(num_elements); + absl::int128 rhs = 0; + for (int i = 0; i < num_elements; ++i) { + int64_t min = absl::Uniform(random, -1000000, 1000000); + int64_t max = absl::Uniform(random, -1000000, 1000000); + if (min > max) std::swap(min, max); + lbs[i] = min; + ubs[i] = max; + particular_solution[i] = absl::Uniform(random, min, max + 1); + rhs += coeffs[i] * particular_solution[i]; + } + DiophantineSolution solution = + SolveDiophantine(coeffs, static_cast(rhs), lbs, ubs); + if (solution.no_reformulation_needed) { + int64_t gcd = std::abs(coeffs[0]); + int64_t min_elem = std::abs(coeffs[0]); + for (int i = 1; i < num_elements; ++i) { + min_elem = std::min(min_elem, std::abs(coeffs[i])); + gcd = std::gcd(gcd, std::abs(coeffs[i])); + } + EXPECT_EQ(gcd, min_elem); + return; + } + EXPECT_TRUE(solution.has_solutions); + + // Checks that the particular solution is inside the described solution. 
+ for (int i = 0; i < solution.special_solution.size(); ++i) { + particular_solution[solution.index_permutation[i]] -= + solution.special_solution[i]; + } + int replaced_variable_count = + static_cast(solution.special_solution.size()); + for (int i = num_elements - 1; i >= replaced_variable_count; --i) { + const absl::int128 q = particular_solution[solution.index_permutation[i]]; + particular_solution[solution.index_permutation[i]] = 0; + for (int j = 0; j < solution.kernel_basis[i - 1].size(); ++j) { + particular_solution[solution.index_permutation[j]] -= + q * solution.kernel_basis[i - 1][j]; + } + } + for (int i = replaced_variable_count - 2; i >= 0; --i) { + const int dom_coeff = static_cast(solution.kernel_basis[i][i + 1]); + EXPECT_EQ( + particular_solution[solution.index_permutation[i + 1]] % dom_coeff, 0); + const absl::int128 q = + particular_solution[solution.index_permutation[i + 1]] / dom_coeff; + EXPECT_LE(solution.kernel_vars_lbs[i], q); + EXPECT_LE(q, solution.kernel_vars_ubs[i]); + for (int j = 0; j < solution.kernel_basis[i].size(); ++j) { + particular_solution[solution.index_permutation[j]] -= + q * solution.kernel_basis[i][j]; + } + } + for (const absl::int128 s : particular_solution) { + EXPECT_EQ(s, 0); + } +} + +INSTANTIATE_TEST_SUITE_P(RandomTests, RandomTest, + ::testing::Range(0, DEBUG_MODE ? 1000 : 10000)); + +} // namespace + +} // namespace operations_research::sat diff --git a/ortools/sat/feasibility_pump_test.cc b/ortools/sat/feasibility_pump_test.cc new file mode 100644 index 0000000000..40563f97d6 --- /dev/null +++ b/ortools/sat/feasibility_pump_test.cc @@ -0,0 +1,190 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/feasibility_pump.h" + +#include + +#include "gtest/gtest.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_loader.h" +#include "ortools/sat/cp_model_mapping.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +constexpr double kTolerance = 1e-6; + +int AddVariable(int lb, int ub, CpModelProto* model) { + const int index = model->variables_size(); + sat::IntegerVariableProto* var = model->add_variables(); + var->add_domain(lb); + var->add_domain(ub); + return index; +} + +TEST(FPTest, SimpleTest) { + Model model; + CpModelProto model_proto; + AddVariable(0, 50, &model_proto); + AddVariable(0, 20, &model_proto); + auto* mapping = model.GetOrCreate(); + LoadVariables(model_proto, false, &model); + IntegerVariable x = mapping->Integer(0); + IntegerVariable y = mapping->Integer(1); + FeasibilityPump* fp = model.Create(); + fp->SetMaxFPIterations(3); + + LinearConstraintBuilder ct(IntegerValue(4), IntegerValue(8)); + ct.AddTerm(x, IntegerValue(2)); + ct.AddTerm(y, IntegerValue(1)); + fp->AddLinearConstraint(ct.Build()); + + fp->SetObjectiveCoefficient(x, IntegerValue(6)); + fp->SetObjectiveCoefficient(y, IntegerValue(3)); + + EXPECT_TRUE(fp->Solve()); + + EXPECT_TRUE(fp->HasLPSolution()); + EXPECT_NEAR(12.0, fp->LPSolutionObjectiveValue(), kTolerance); + 
EXPECT_TRUE(fp->LPSolutionIsInteger()); +} + +TEST(FPTest, InfeasibilityTest) { + Model model; + CpModelProto model_proto; + model.GetOrCreate()->set_fp_rounding( + SatParameters::PROPAGATION_ASSISTED); + AddVariable(0, 2, &model_proto); + auto* mapping = model.GetOrCreate(); + LoadVariables(model_proto, false, &model); + IntegerVariable x = mapping->Integer(0); + FeasibilityPump* fp = model.Create(); + fp->SetMaxFPIterations(3); + + // Note(user): We don't rely on the imprecise LP to report infeasibility. + model.GetOrCreate()->NotifyThatModelIsUnsat(); + // x = 1/4 + LinearConstraintBuilder ct(IntegerValue(1), IntegerValue(1)); + ct.AddTerm(x, IntegerValue(4)); + fp->AddLinearConstraint(ct.Build()); + + fp->SetObjectiveCoefficient(x, IntegerValue(4)); + + EXPECT_FALSE(fp->Solve()); +} + +// int x,y in [0,50] +// min -x -2y +// 2x + 3y <= 12 +// 3x + 2y <= 12 +// -x + y = 0 +// +// LP solution: -7.2 (x = y = 2.4). +// Integer feasible solution: -6 (x = y = 2). +class FeasibilityPumpTest : public testing::Test { + public: + FeasibilityPumpTest() { + AddVariable(0, 50, &model_proto_); + AddVariable(0, 50, &model_proto_); + auto* mapping = model_.GetOrCreate(); + LoadVariables(model_proto_, false, &model_); + IntegerVariable x_ = mapping->Integer(0); + IntegerVariable y_ = mapping->Integer(1); + fp_ = model_.Create(); + + LinearConstraintBuilder ct(kMinIntegerValue, IntegerValue(12)); + ct.AddTerm(x_, IntegerValue(3)); + ct.AddTerm(y_, IntegerValue(2)); + fp_->AddLinearConstraint(ct.Build()); + + LinearConstraintBuilder ct2(kMinIntegerValue, IntegerValue(12)); + ct2.AddTerm(x_, IntegerValue(2)); + ct2.AddTerm(y_, IntegerValue(3)); + fp_->AddLinearConstraint(ct2.Build()); + + LinearConstraintBuilder ct3(IntegerValue(0), IntegerValue(0)); + ct3.AddTerm(x_, IntegerValue(-1)); + ct3.AddTerm(y_, IntegerValue(1)); + fp_->AddLinearConstraint(ct3.Build()); + + fp_->SetObjectiveCoefficient(x_, IntegerValue(-1)); + fp_->SetObjectiveCoefficient(y_, IntegerValue(-2)); + } + + 
FeasibilityPump* fp_; + Model model_; + CpModelProto model_proto_; + IntegerVariable x_; + IntegerVariable y_; +}; + +TEST_F(FeasibilityPumpTest, SimpleRounding) { + fp_->SetMaxFPIterations(1); + EXPECT_TRUE(fp_->Solve()); + + EXPECT_TRUE(fp_->HasLPSolution()); + EXPECT_NEAR(-7.2, fp_->LPSolutionObjectiveValue(), kTolerance); + EXPECT_FALSE(fp_->LPSolutionIsInteger()); + EXPECT_NEAR(0.4, fp_->LPSolutionFractionality(), kTolerance); + + EXPECT_TRUE(fp_->HasIntegerSolution()); + EXPECT_EQ(-6, fp_->IntegerSolutionObjectiveValue()); + EXPECT_TRUE(fp_->IntegerSolutionIsFeasible()); +} + +TEST_F(FeasibilityPumpTest, MultipleIterations) { + fp_->SetMaxFPIterations(5); + EXPECT_TRUE(fp_->Solve()); + + EXPECT_TRUE(fp_->HasLPSolution()); + EXPECT_NEAR(-6, fp_->LPSolutionObjectiveValue(), kTolerance); + EXPECT_TRUE(fp_->LPSolutionIsInteger()); + + EXPECT_TRUE(fp_->HasIntegerSolution()); + EXPECT_EQ(-6, fp_->IntegerSolutionObjectiveValue()); + EXPECT_TRUE(fp_->IntegerSolutionIsFeasible()); +} + +TEST_F(FeasibilityPumpTest, MultipleCalls) { + EXPECT_TRUE(fp_->Solve()); + + EXPECT_TRUE(fp_->HasLPSolution()); + EXPECT_NEAR(-6, fp_->LPSolutionObjectiveValue(), kTolerance); + EXPECT_TRUE(fp_->LPSolutionIsInteger()); + + EXPECT_TRUE(fp_->HasIntegerSolution()); + EXPECT_EQ(-6, fp_->IntegerSolutionObjectiveValue()); + EXPECT_TRUE(fp_->IntegerSolutionIsFeasible()); + + // Change bounds. 
+ auto* integer_trail = model_.GetOrCreate(); + ASSERT_TRUE(integer_trail->Enqueue( + IntegerLiteral::LowerOrEqual(x_, IntegerValue(1)), {}, {})); + EXPECT_TRUE(fp_->Solve()); + + EXPECT_TRUE(fp_->HasLPSolution()); + EXPECT_NEAR(-3, fp_->LPSolutionObjectiveValue(), kTolerance); + EXPECT_TRUE(fp_->LPSolutionIsInteger()); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/lb_tree_search_test.cc b/ortools/sat/lb_tree_search_test.cc new file mode 100644 index 0000000000..a67d03e0a5 --- /dev/null +++ b/ortools/sat/lb_tree_search_test.cc @@ -0,0 +1,38 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/lb_tree_search.h" + +#include "gtest/gtest.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/cp_model_test_utils.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace operations_research { +namespace sat { +namespace { + +// This just check that the code compile and runs. 
+TEST(LbTreeSearch, BooleanLinearOptimizationProblem) { + const CpModelProto model_proto = RandomLinearProblem(50, 50); + SatParameters params; + params.set_optimize_with_lb_tree_search(true); + params.set_log_search_progress(true); + const CpSolverResponse response = SolveWithParameters(model_proto, params); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index b06a628fbc..12d269c568 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -106,57 +106,6 @@ bool ScatteredIntegerVector::Add(glop::ColIndex col, IntegerValue value) { return true; } -template -bool ScatteredIntegerVector::AddLinearExpressionMultiple( - const IntegerValue multiplier, absl::Span cols, - absl::Span coeffs, IntegerValue max_coeff_magnitude) { - // Since we have the norm, this avoid checking each products below. 
- if (check_overflow) { - const IntegerValue prod = CapProdI(max_coeff_magnitude, multiplier); - if (AtMinOrMaxInt64(prod.value())) return false; - } - - IntegerValue* data = dense_vector_.data(); - const double threshold = 0.1 * static_cast(dense_vector_.size()); - const int num_terms = cols.size(); - if (is_sparse_ && static_cast(num_terms) < threshold) { - for (int i = 0; i < num_terms; ++i) { - const glop::ColIndex col = cols[i]; - if (is_zeros_[col]) { - is_zeros_[col] = false; - non_zeros_.push_back(col); - } - const IntegerValue product = multiplier * coeffs[i]; - if (check_overflow) { - if (AddIntoOverflow(product.value(), - data[col.value()].mutable_value())) { - return false; - } - } else { - data[col.value()] += product; - } - } - if (static_cast(non_zeros_.size()) > threshold) { - is_sparse_ = false; - } - } else { - is_sparse_ = false; - for (int i = 0; i < num_terms; ++i) { - const glop::ColIndex col = cols[i]; - const IntegerValue product = multiplier * coeffs[i]; - if (check_overflow) { - if (AddIntoOverflow(product.value(), - data[col.value()].mutable_value())) { - return false; - } - } else { - data[col.value()] += product; - } - } - } - return true; -} - LinearConstraint ScatteredIntegerVector::ConvertToLinearConstraint( absl::Span integer_variables, IntegerValue upper_bound, diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index 093e68de9a..cbc40391b6 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -74,7 +74,53 @@ class ScatteredIntegerVector { bool AddLinearExpressionMultiple(IntegerValue multiplier, absl::Span cols, absl::Span coeffs, - IntegerValue max_coeff_magnitude); + IntegerValue max_coeff_magnitude) { + // Since we have the norm, this avoid checking each products below. 
+ if (check_overflow) { + const IntegerValue prod = CapProdI(max_coeff_magnitude, multiplier); + if (AtMinOrMaxInt64(prod.value())) return false; + } + + IntegerValue* data = dense_vector_.data(); + const double threshold = 0.1 * static_cast(dense_vector_.size()); + const int num_terms = cols.size(); + if (is_sparse_ && static_cast(num_terms) < threshold) { + for (int i = 0; i < num_terms; ++i) { + const glop::ColIndex col = cols[i]; + if (is_zeros_[col]) { + is_zeros_[col] = false; + non_zeros_.push_back(col); + } + const IntegerValue product = multiplier * coeffs[i]; + if (check_overflow) { + if (AddIntoOverflow(product.value(), + data[col.value()].mutable_value())) { + return false; + } + } else { + data[col.value()] += product; + } + } + if (static_cast(non_zeros_.size()) > threshold) { + is_sparse_ = false; + } + } else { + is_sparse_ = false; + for (int i = 0; i < num_terms; ++i) { + const glop::ColIndex col = cols[i]; + const IntegerValue product = multiplier * coeffs[i]; + if (check_overflow) { + if (AddIntoOverflow(product.value(), + data[col.value()].mutable_value())) { + return false; + } + } else { + data[col.value()] += product; + } + } + } + return true; +} // This is not const only because non_zeros is sorted. Note that sorting the // non-zeros make the result deterministic whether or not we were in sparse diff --git a/ortools/sat/linear_programming_constraint_test.cc b/ortools/sat/linear_programming_constraint_test.cc new file mode 100644 index 0000000000..73ec064503 --- /dev/null +++ b/ortools/sat/linear_programming_constraint_test.cc @@ -0,0 +1,309 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/linear_programming_constraint.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/mathutil.h" +#include "ortools/lp_data/lp_types.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/linear_constraint_manager.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(ScatteredIntegerVectorTest, BasicDenseBehavior) { + ScatteredIntegerVector v; + v.ClearAndResize(10); + v.AddLinearExpressionMultiple(IntegerValue(4), {glop::ColIndex(2)}, + {IntegerValue(3)}, IntegerValue(3)); + v.AddLinearExpressionMultiple(IntegerValue(3), {glop::ColIndex(1)}, + {IntegerValue(3)}, IntegerValue(3)); + v.AddLinearExpressionMultiple(IntegerValue(5), {glop::ColIndex(2)}, + {IntegerValue(3)}, IntegerValue(3)); + const std::vector> expected{ + {glop::ColIndex(1), IntegerValue(3 * 3)}, + {glop::ColIndex(2), IntegerValue(3 * 4 + 5 * 3)}}; + EXPECT_FALSE(v.IsSparse()); + EXPECT_EQ(v.GetTerms(), expected); +} + +TEST(ScatteredIntegerVectorTest, BasicSparseBehavior) { + ScatteredIntegerVector v; + v.ClearAndResize(100000); + 
v.AddLinearExpressionMultiple(IntegerValue(4), {glop::ColIndex(2)}, + {IntegerValue(3)}, IntegerValue(3)); + v.AddLinearExpressionMultiple(IntegerValue(3), {glop::ColIndex(1)}, + {IntegerValue(3)}, IntegerValue(3)); + v.AddLinearExpressionMultiple(IntegerValue(5), {glop::ColIndex(2)}, + {IntegerValue(3)}, IntegerValue(3)); + const std::vector> expected{ + {glop::ColIndex(1), IntegerValue(3 * 3)}, + {glop::ColIndex(2), IntegerValue(3 * 4 + 5 * 3)}}; + EXPECT_EQ(v.GetTerms(), expected); + EXPECT_TRUE(v.IsSparse()); +} + +// TODO(user): Check that SAT solutions respect linear equations. +struct LPProblem { + const std::vector integer_lb; + const std::vector integer_ub; + const std::vector constraint_lb; + const std::vector constraint_ub; + const std::vector> constraint_integer_indices; + const std::vector> constraint_integer_coefs; + const std::vector objective_indices; + + int num_integer_vars() const { return integer_lb.size(); } + int num_constraints() const { return constraint_lb.size(); } + int num_objectives() const { return objective_indices.size(); } +}; + +// Generates a problem that encodes a permutation as a bipartite matching +// from size 'left' nodes to size 'right' nodes. +// Decision variables are edges linking left nodes to right nodes. +// For a given node, only one adjacent edge can be present, +// which is encoded by one constraint per node. +LPProblem GeneratePermutationProblem(int size) { + const std::vector edge_lb(size * size, IntegerValue(0)); + const std::vector edge_ub(size * size, IntegerValue(1)); + + const std::vector node_lb(2 * size, IntegerValue(1)); + const std::vector node_ub(2 * size, IntegerValue(1)); + + std::vector> node_constraint_indices; + std::vector> node_constraint_coefs; + + // Left and right nodes are indexed by [0, size). + // The edge (left, right) has number left * size + right. 
+ const std::vector ones(size, IntegerValue(1)); + for (int left = 0; left < size; left++) { + std::vector indices; + for (int right = 0; right < size; right++) { + indices.push_back(left * size + right); + } + node_constraint_indices.push_back(indices); + node_constraint_coefs.push_back(ones); + } + + for (int right = 0; right < size; right++) { + std::vector indices; + for (int left = 0; left < size; left++) { + indices.push_back(left * size + right); + } + node_constraint_indices.push_back(indices); + node_constraint_coefs.push_back(ones); + } + + return LPProblem{edge_lb, + edge_ub, + node_lb, + node_ub, + node_constraint_indices, + node_constraint_coefs, + {}}; +} + +int CountSolutionsOfLPProblemUsingSAT(const LPProblem& problem) { + Model model; + + std::vector cp_variables; + const int num_cp_vars = problem.num_integer_vars(); + for (int i = 0; i < num_cp_vars; i++) { + IntegerVariable var = model.Add(NewIntegerVariable( + problem.integer_lb[i].value(), problem.integer_ub[i].value())); + cp_variables.push_back(var); + } + + LinearProgrammingConstraint* lp = + new LinearProgrammingConstraint(&model, cp_variables); + model.TakeOwnership(lp); + + const int num_constraints = problem.num_constraints(); + for (int c = 0; c < num_constraints; c++) { + LinearConstraint ct(problem.constraint_lb[c], problem.constraint_ub[c]); + const int num_integer = problem.constraint_integer_indices[c].size(); + ct.resize(num_integer); + for (int j = 0; j < num_integer; j++) { + ct.vars[j] = cp_variables[problem.constraint_integer_indices[c][j]]; + ct.coeffs[j] = problem.constraint_integer_coefs[c][j]; + } + lp->AddLinearConstraint(std::move(ct)); + } + + lp->RegisterWith(&model); + + int num_solutions = 0; + while (SolveIntegerProblemWithLazyEncoding(&model) == + SatSolver::Status::FEASIBLE) { + model.Add(ExcludeCurrentSolutionAndBacktrack()); + num_solutions++; + } + return num_solutions; +} + +TEST(LinearProgrammingConstraintTest, CountPermutations) { + int factorial_of_size = 
1; + for (int size = 2; size < 6; size++) { + factorial_of_size *= size; + LPProblem problem = GeneratePermutationProblem(size); + ASSERT_EQ(CountSolutionsOfLPProblemUsingSAT(problem), factorial_of_size); + } +} + +TEST(LinearProgrammingConstraintTest, SimpleInfeasibility) { + // The following flow is infeasible, LP should detect it. + // + // source + // = 2/2 0/2 + // ---------> ---------> A + // | | + // | | + // 0/2 | | 0/2 + // | | + // V V + // B -------> -------> + // 0/2 [3,4] = sink + Model model; + + // We need this parameter at false to detect it on the first propagation. + model.Add(NewSatParameters("add_lp_constraints_lazily:false")); + + // Variables for the source and sink demands, and flow amounts. + const IntegerVariable source = model.Add(NewIntegerVariable(2, 2)); + const IntegerVariable source_a = model.Add(NewIntegerVariable(0, 2)); + const IntegerVariable source_b = model.Add(NewIntegerVariable(0, 2)); + const IntegerVariable a_sink = model.Add(NewIntegerVariable(0, 2)); + const IntegerVariable b_sink = model.Add(NewIntegerVariable(0, 2)); + const IntegerVariable sink = model.Add(NewIntegerVariable(3, 4)); + + // LP Constraint and flow conservation equalities. 
+ LinearProgrammingConstraint* lp = new LinearProgrammingConstraint( + &model, {source, source_a, source_b, a_sink, b_sink, sink}); + model.TakeOwnership(lp); + + LinearConstraintBuilder ct_source(IntegerValue(0), IntegerValue(0)); + ct_source.AddTerm(source, IntegerValue(1)); + ct_source.AddTerm(source_a, IntegerValue(-1)); + ct_source.AddTerm(source_b, IntegerValue(-1)); + lp->AddLinearConstraint(ct_source.Build()); + + LinearConstraintBuilder ct_a(IntegerValue(0), IntegerValue(0)); + ct_a.AddTerm(source_a, IntegerValue(1)); + ct_a.AddTerm(a_sink, IntegerValue(-1)); + lp->AddLinearConstraint(ct_a.Build()); + + LinearConstraintBuilder ct_b(IntegerValue(0), IntegerValue(0)); + ct_b.AddTerm(source_b, IntegerValue(1)); + ct_b.AddTerm(b_sink, IntegerValue(-1)); + lp->AddLinearConstraint(ct_b.Build()); + + LinearConstraintBuilder ct_sink(IntegerValue(0), IntegerValue(0)); + ct_sink.AddTerm(a_sink, IntegerValue(1)); + ct_sink.AddTerm(b_sink, IntegerValue(1)); + ct_sink.AddTerm(sink, IntegerValue(-1)); + lp->AddLinearConstraint(ct_sink.Build()); + + lp->RegisterWith(&model); + + ASSERT_FALSE(model.GetOrCreate()->Propagate()); +} + +TEST(LinearProgrammingConstraintTest, EmptyLP) { + Model model; + model.Add(NewSatParameters("linearization_level:2")); + + LinearProgrammingConstraint* lp = new LinearProgrammingConstraint(&model, {}); + model.TakeOwnership(lp); + lp->RegisterWith(&model); + + ASSERT_TRUE(model.GetOrCreate()->Propagate()); +} + +// This tests that the scaling of reduced costs is done correctly, +// using the problem +// min 64 x s.t. 0 <= 27x <= 81, x in [0, 50] +// The reduced cost of x should be 64. +// This will be wrong if the reduced cost ignores the objective scaling. 
+TEST(LinearProgrammingConstraintTest, ReducedCostScalingObjective) { + Model m; + IntegerVariable x = m.Add(NewIntegerVariable(0, 50)); + LinearProgrammingConstraint* lp = new LinearProgrammingConstraint(&m, {x}); + m.TakeOwnership(lp); + + LinearConstraintBuilder ct(IntegerValue(0), IntegerValue(81)); + ct.AddTerm(x, IntegerValue(27)); + lp->AddLinearConstraint(ct.Build()); + + IntegerVariable obj = m.Add(NewIntegerVariable(0, 64 * 50)); + lp->SetObjectiveCoefficient(x, IntegerValue(64)); + lp->SetMainObjectiveVariable(obj); + + lp->RegisterWith(&m); + lp->Propagate(); + const auto& reduced_costs = *m.GetOrCreate(); + CHECK_LE(std::abs(reduced_costs[x] - 64), 1e-6); +} + +// This tests that the scaling of reduced costs is done correctly, +// using the problem +// min 64 x + 32 y s.t. 0 <= 27x + 9y <= 81, x in [0, 50], y in [0, 20] +// The reduced cost of x should be 64, y should be 20. +// This will be wrong if the reduced cost ignores the column scaling. +TEST(LinearProgrammingConstraintTest, ReducedCostScalingColumns) { + Model m; + IntegerVariable x = m.Add(NewIntegerVariable(0, 50)); + IntegerVariable y = m.Add(NewIntegerVariable(0, 20)); + LinearProgrammingConstraint* lp = new LinearProgrammingConstraint(&m, {x, y}); + m.TakeOwnership(lp); + + LinearConstraintBuilder ct(IntegerValue(0), IntegerValue(81)); + ct.AddTerm(x, IntegerValue(27)); + ct.AddTerm(y, IntegerValue(9)); + lp->AddLinearConstraint(ct.Build()); + + IntegerVariable obj = m.Add(NewIntegerVariable(0, 64 * 50 + 32 * 20)); + lp->SetObjectiveCoefficient(x, IntegerValue(64)); + lp->SetObjectiveCoefficient(y, IntegerValue(32)); + lp->SetMainObjectiveVariable(obj); + + lp->RegisterWith(&m); + lp->Propagate(); + const auto& reduced_costs = *m.GetOrCreate(); + CHECK_LE(std::abs(reduced_costs[x] - 64), 1e-6); + CHECK_LE(std::abs(reduced_costs[y] - 32), 1e-6); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/presolve_context.cc 
b/ortools/sat/presolve_context.cc index 9c8bccbd86..31e43a8280 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -2569,145 +2569,5 @@ ConstraintProto* PresolveContext::NewMappingConstraint( return new_ct; } -void CanonicalizeTable(PresolveContext* context, ConstraintProto* ct) { - if (context->ModelIsUnsat()) return; - - const int num_exprs = ct->table().exprs_size(); - const int num_tuples = ct->table().values_size() / num_exprs; - - // Detect expressions sharing the same variable as a previous expression. - absl::flat_hash_map var_to_position; - - // The mapping between the position in the original list of expressions, and - // the position in the reduced list of expressions. - std::vector> position_mapping(num_exprs, std::nullopt); - int num_shared_vars = 0; - for (int i = 0; i < num_exprs; ++i) { - const LinearExpressionProto& expr = ct->table().exprs(i); - if (context->IsFixed(expr)) continue; - - const int var = expr.vars(0); - const auto [it, inserted] = - var_to_position.insert({var, var_to_position.size()}); - if (!inserted) { - ++num_shared_vars; - position_mapping[i] = it->second; - } - } - - const int num_kept_exprs = num_exprs - num_shared_vars; - - std::vector> new_tuples; - new_tuples.reserve(num_tuples); - - std::vector new_scaled_values; - new_scaled_values.reserve(num_kept_exprs); - - for (int t = 0; t < num_tuples; ++t) { - bool tuple_is_valid = true; - new_scaled_values.clear(); - - for (int e = 0; e < num_exprs; ++e) { - const int64_t value = ct->table().values(t * num_exprs + e); - const LinearExpressionProto& expr = ct->table().exprs(e); - if (context->IsFixed(expr)) { - if (value != context->FixedValue(expr)) { - tuple_is_valid = false; - break; - } - new_scaled_values.push_back(value); - } else if (position_mapping[e].has_value()) { - const int var_first_position = position_mapping[e].value(); - const int64_t var_value = new_scaled_values[var_first_position]; - const int64_t forced_value = 
AffineExpressionValueAt(expr, var_value); - if (value != forced_value) { - tuple_is_valid = false; - break; - } - } else { - if (!context->DomainContains(expr, value)) { - tuple_is_valid = false; - break; - } - new_scaled_values.push_back(GetInnerVarValue(expr, value)); - } - } - - if (tuple_is_valid) { - DCHECK_EQ(new_scaled_values.size(), num_kept_exprs); - new_tuples.push_back(new_scaled_values); - } - } - - // Remove all scaling on expressions as we have stored the inner values. - for (int e = 0; e < num_exprs; ++e) { - if (position_mapping[e].has_value()) continue; - if (context->IsFixed(ct->table().exprs(e))) continue; - DCHECK_EQ(ct->table().exprs(e).coeffs_size(), 1); - ct->mutable_table()->mutable_exprs(e)->set_offset(0); - ct->mutable_table()->mutable_exprs(e)->set_coeffs(0, 1); - } - - if (num_kept_exprs < num_exprs) { - int index = 0; - for (int e = 0; e < num_exprs; ++e) { - if (position_mapping[e].has_value()) continue; - ct->mutable_table()->mutable_exprs()->SwapElements(index++, e); - } - CHECK_EQ(index, num_kept_exprs); - ct->mutable_table()->mutable_exprs()->DeleteSubrange(index, - num_exprs - index); - context->UpdateRuleStats("table: remove expressions"); - } - - gtl::STLSortAndRemoveDuplicates(&new_tuples); - if (new_tuples.size() < num_tuples) { - context->UpdateRuleStats("table: remove tuples"); - } - - // Write sorted tuples. 
- ct->mutable_table()->clear_values(); - for (const std::vector& tuple : new_tuples) { - ct->mutable_table()->mutable_values()->Add(tuple.begin(), tuple.end()); - } -} - -void RemoveFixedColumnsFromTable(PresolveContext* context, - ConstraintProto* ct) { - if (context->ModelIsUnsat()) return; - const int num_exprs = ct->table().exprs_size(); - const int num_tuples = ct->table().values_size() / num_exprs; - std::vector is_fixed(num_exprs, false); - int num_fixed_exprs = 0; - for (int e = 0; e < num_exprs; ++e) { - is_fixed[e] = context->IsFixed(ct->table().exprs(e)); - num_fixed_exprs += is_fixed[e]; - } - if (num_fixed_exprs == 0) return; - - int num_kept_exprs = num_exprs - num_fixed_exprs; - - int index = 0; - for (int e = 0; e < num_exprs; ++e) { - if (is_fixed[e]) continue; - ct->mutable_table()->mutable_exprs()->SwapElements(index++, e); - } - CHECK_EQ(index, num_kept_exprs); - ct->mutable_table()->mutable_exprs()->DeleteSubrange(index, - num_exprs - index); - index = 0; - for (int t = 0; t < num_tuples; ++t) { - for (int e = 0; e < num_exprs; ++e) { - if (is_fixed[e]) continue; - ct->mutable_table()->set_values(index++, - ct->table().values(t * num_exprs + e)); - } - } - CHECK_EQ(index, num_tuples * num_kept_exprs); - ct->mutable_table()->mutable_values()->Truncate(index); - - context->UpdateRuleStats("table: remove fixed columns"); -} - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 98bbaa0fc7..a2adcd2db5 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -15,7 +15,6 @@ #define OR_TOOLS_SAT_PRESOLVE_CONTEXT_H_ #include -#include #include #include #include @@ -288,7 +287,7 @@ class PresolveContext { // At the beginning of the presolve, we delay the costly creation of this // "graph" until we at least ran some basic presolve. 
This is because during - a LNS neighbhorhood, many constraints will be reduced significantly by + this "simple" presolve. bool ConstraintVariableGraphIsUpToDate() const; @@ -461,7 +460,7 @@ class PresolveContext { bool RecomputeSingletonObjectiveDomain(); // Some function need the domain to be up to date in the proto. - // This make sures our in-memory domain are writted back to the proto. + // This makes sure our in-memory domains are written back to the proto. void WriteVariableDomainsToProto() const; // Checks if the given exactly_one is included in the objective, and simplify @@ -774,15 +773,6 @@ class PresolveContext { // that will be used for probing. Returns false if UNSAT. bool LoadModelForProbing(PresolveContext* context, Model* local_model); -// Canonicalizes the table constraint by removing all unreachable tuples, and -// all columns which have the same variable of a previous column. -// -// This also sort all the tuples. -void CanonicalizeTable(PresolveContext* context, ConstraintProto* ct); - -// Removed all fixed columns from the table. 
-void RemoveFixedColumnsFromTable(PresolveContext* context, ConstraintProto* ct); - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/presolve_util.h b/ortools/sat/presolve_util.h index d584a68d16..c663ea8909 100644 --- a/ortools/sat/presolve_util.h +++ b/ortools/sat/presolve_util.h @@ -14,7 +14,6 @@ #ifndef OR_TOOLS_SAT_PRESOLVE_UTIL_H_ #define OR_TOOLS_SAT_PRESOLVE_UTIL_H_ -#include #include #include #include @@ -24,12 +23,9 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/bit_gen_ref.h" -#include "absl/random/random.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "ortools/base/logging.h" #include "ortools/base/strong_vector.h" -#include "ortools/base/types.h" +#include "ortools/base/timer.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/util.h" @@ -86,7 +82,7 @@ class PresolveTimer { // If for each literal of a clause, we can infer a domain on an integer // variable, then we know that this variable domain is included in the union of -// such infered domains. +// such inferred domains. // // This allows to propagate "element" like constraints encoded as enforced // linear relations, and other more general reasoning. diff --git a/ortools/sat/sat_cnf_reader_test.cc b/ortools/sat/sat_cnf_reader_test.cc new file mode 100644 index 0000000000..d04daf3de9 --- /dev/null +++ b/ortools/sat/sat_cnf_reader_test.cc @@ -0,0 +1,168 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/sat_cnf_reader.h" + +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gtest/gtest.h" +#include "ortools/base/helpers.h" +#include "ortools/base/options.h" +#include "ortools/base/path.h" +#include "ortools/sat/boolean_problem.h" +#include "ortools/sat/boolean_problem.pb.h" + +namespace operations_research { +namespace sat { +namespace { + +std::string WriteTmpFileOrDie(absl::string_view content) { + static int counter = 0; + const std::string filename = file::JoinPath( + ::testing::TempDir(), absl::StrCat("file_", counter++, ".cnf")); + CHECK_OK(file::SetContents(filename, content, file::Defaults())); + return filename; +} + +TEST(SatCnfReader, CnfFormat) { + std::string file_content = + "p cnf 5 4\n" + "+1 +2 +3 0\n" + "-4 -5 0\n" + "+1 0\n" + "-1 0\n"; + SatCnfReader reader; + LinearBooleanProblem problem; + EXPECT_TRUE(reader.Load(WriteTmpFileOrDie(file_content), &problem)); + EXPECT_EQ(file_content, LinearBooleanProblemToCnfString(problem)); +} + +TEST(SatCnfReader, CnfFormatAsMaxSat) { + const std::string file_content = + "p cnf 5 4\n" + "+1 +2 +3 0\n" + "-4 -5 0\n" + "+1 0\n" + "-1 0\n"; + SatCnfReader reader(/*wcnf_use_strong_slack=*/false); + reader.InterpretCnfAsMaxSat(true); + LinearBooleanProblem problem; + EXPECT_TRUE(reader.Load(WriteTmpFileOrDie(file_content), &problem)); + + // Note that we currently lose the objective offset of 1 due to the two + // clauses +1 and -1 in the original problem. 
+ const std::string wcnf_output = + "p wcnf 5 2 3\n" + "1 +1 +2 +3 0\n" + "1 -4 -5 0\n"; + EXPECT_EQ(wcnf_output, LinearBooleanProblemToCnfString(problem)); +} + +TEST(SatCnfReader, CnfFormatCornerCases) { + std::string file_content = + "p cnf 5 4\n" + "+1 +2 +3 0\n" + "-4 -5 0\n" + "+1 0\n" + "-1 0\n"; + std::string file_content_with_comments = + "c comments are ignored\n" + "p cnf 5 4\n" + "c + are not mandatory: \n" + "+1 2 +3 0\n" + "c and can be anywhere, with 0 0 0\n" + "-4 -5 0\n" + "c empty line are ignored\n" + "\n\n\n \n" + "+1 0\n" + "c same for spaces:\n" + " -1 0\n"; + SatCnfReader reader; + LinearBooleanProblem problem; + EXPECT_TRUE( + reader.Load(WriteTmpFileOrDie(file_content_with_comments), &problem)); + EXPECT_EQ(file_content, LinearBooleanProblemToCnfString(problem)); +} + +TEST(SatCnfReader, ClausesNumberMustMatch) { + std::string file_content = + "p cnf 5 4\n" + "+1 +2 +3 0\n" + "-4 -5 0\n" + "+1 0\n" + "0\n" + "-1 0\n"; + SatCnfReader reader; + LinearBooleanProblem problem; + + // Note that we changed that requirement since now we dynamically infer sizes. + // We just log errors. + EXPECT_TRUE(reader.Load(WriteTmpFileOrDie(file_content), &problem)); +} + +TEST(SatCnfReader, WcnfFormat) { + // Note that the input format is such that it is the same as the one produced + // by LinearBooleanProblemToCnfString(). + // + // The special hard weight "109" is by convention the sum of all the soft + // weight + 1. It means that not satisfying an hard clause is worse than + // satisfying none of the soft clause. Note that this is just a "convention", + // in the way we interpret it, it doesn't really matter and must just be a + // different number than any of the soft weight. 
+ std::string file_content = + "p wcnf 5 7 109\n" + "1 +1 +2 +3 0\n" + "2 -4 -5 0\n" + "109 -1 0\n" + "109 +1 0\n" + "99 +1 0\n" + "3 +4 0\n" + "3 +5 0\n"; + SatCnfReader reader(/*wcnf_use_strong_slack=*/false); + LinearBooleanProblem problem; + EXPECT_TRUE(reader.Load(WriteTmpFileOrDie(file_content), &problem)); + EXPECT_EQ(4, problem.constraints_size()); + EXPECT_EQ(file_content, LinearBooleanProblemToCnfString(problem)); +} + +TEST(SatCnfReader, WcnfNewFormat) { + const std::string new_format = + "1 +1 +2 +3 0\n" + "2 -4 -5 0\n" + "h -1 0\n" + "h +1 0\n" + "99 +1 0\n" + "3 +4 0\n" + "3 +5 0\n"; + const std::string file_content = + "p wcnf 5 7 109\n" + "1 +1 +2 +3 0\n" + "2 -4 -5 0\n" + "109 -1 0\n" + "109 +1 0\n" + "99 +1 0\n" + "3 +4 0\n" + "3 +5 0\n"; + SatCnfReader reader(/*wcnf_use_strong_slack=*/false); + LinearBooleanProblem problem; + EXPECT_TRUE(reader.Load(WriteTmpFileOrDie(new_format), &problem)); + EXPECT_EQ(4, problem.constraints_size()); + EXPECT_EQ(file_content, LinearBooleanProblemToCnfString(problem)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/sat_decision_test.cc b/ortools/sat/sat_decision_test.cc new file mode 100644 index 0000000000..2491fc17bd --- /dev/null +++ b/ortools/sat/sat_decision_test.cc @@ -0,0 +1,101 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/sat_decision.h" + +#include +#include +#include + +#include "absl/random/random.h" +#include "gtest/gtest.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(SatDecisionPolicyTest, ExternalPreferences) { + Model model; + Trail* trail = model.GetOrCreate(); + SatDecisionPolicy* decision = model.GetOrCreate(); + + const int num_variables = 1000; + trail->Resize(num_variables); + decision->IncreaseNumVariables(num_variables); + + // Generate some arbitrary priorities (all different). + std::vector> var_with_preference( + num_variables); + for (int i = 0; i < num_variables; ++i) { + var_with_preference[i] = {BooleanVariable(i), + static_cast(i) / num_variables}; + } + + absl::BitGen random; + // Add them in random order. + std::shuffle(var_with_preference.begin(), var_with_preference.end(), random); + for (const auto p : var_with_preference) { + const Literal literal(p.first, true); + decision->SetAssignmentPreference(literal, p.second); + } + + // Expect them in decreasing order. + std::sort(var_with_preference.begin(), var_with_preference.end()); + std::reverse(var_with_preference.begin(), var_with_preference.end()); + for (int i = 0; i < num_variables; ++i) { + const Literal literal(var_with_preference[i].first, true); + EXPECT_EQ(literal, decision->NextBranch()); + trail->EnqueueSearchDecision(literal); + } +} + +TEST(SatDecisionPolicyTest, ErwaHeuristic) { + Model model; + auto* params = model.GetOrCreate(); + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* decision = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + params->set_use_erwa_heuristic(true); + decision->ResetDecisionHeuristic(); + + // Default. 
+ EXPECT_EQ(Literal(BooleanVariable(0), false), decision->NextBranch()); + + // Do not return assigned decision. + // Note that the priority queue tie-breaking is not in order. + sat_solver->EnqueueDecisionIfNotConflicting( + Literal(BooleanVariable(0), true)); + EXPECT_EQ(Literal(BooleanVariable(9), false), decision->NextBranch()); + + // Lets enqueue some more. + trail->EnqueueWithUnitReason(Literal(BooleanVariable(1), false)); + trail->EnqueueWithUnitReason(Literal(BooleanVariable(2), true)); + + // We can Bump the reason and simulate a conflict. + decision->BumpVariableActivities({Literal(BooleanVariable(2), true)}); + decision->BeforeConflict(trail->Index()); + sat_solver->Backtrack(0); + + // Now this bumped variable is first, with polarity as assigned. + EXPECT_EQ(Literal(BooleanVariable(2), true), decision->NextBranch()); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index 8bc821226e..ff16eb2801 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -143,7 +143,7 @@ void QuotientAndRemainder(int64_t a, int64_t b, int64_t& q, int64_t& r) { } // namespace -// Using the extended Euclidian algo, we find a and b such that +// Using the extended Euclidean algo, we find a and b such that // a x + b m = gcd(x, m) // https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm int64_t ModularInverse(int64_t x, int64_t m) { @@ -451,44 +451,6 @@ double Percentile::GetPercentile(double percent) { return *lower_it + (percentile_rank - lower_rank) * (*upper_it - *lower_it); } -void CompressTuples(absl::Span domain_sizes, - std::vector>* tuples) { - if (tuples->empty()) return; - - // Remove duplicates if any. 
- gtl::STLSortAndRemoveDuplicates(tuples); - - const int num_vars = (*tuples)[0].size(); - - std::vector to_remove; - std::vector tuple_minus_var_i(num_vars - 1); - for (int i = 0; i < num_vars; ++i) { - const int domain_size = domain_sizes[i]; - if (domain_size == 1) continue; - absl::flat_hash_map, std::vector> - masked_tuples_to_indices; - for (int t = 0; t < tuples->size(); ++t) { - int out = 0; - for (int j = 0; j < num_vars; ++j) { - if (i == j) continue; - tuple_minus_var_i[out++] = (*tuples)[t][j]; - } - masked_tuples_to_indices[tuple_minus_var_i].push_back(t); - } - to_remove.clear(); - for (const auto& it : masked_tuples_to_indices) { - if (it.second.size() != domain_size) continue; - (*tuples)[it.second.front()][i] = kTableAnyValue; - to_remove.insert(to_remove.end(), it.second.begin() + 1, it.second.end()); - } - std::sort(to_remove.begin(), to_remove.end(), std::greater()); - for (const int t : to_remove) { - (*tuples)[t] = tuples->back(); - tuples->pop_back(); - } - } -} - void MaxBoundedSubsetSum::Reset(int64_t bound) { DCHECK_GE(bound, 0); gcd_ = 0; @@ -776,117 +738,6 @@ BasicKnapsackSolver::Result BasicKnapsackSolver::InternalSolve( namespace { -// We will call FullyCompressTuplesRecursive() for a set of prefixes of the -// original tuples, each having the same suffix (in reversed_suffix). -// -// For such set, we will compress it on the last variable of the prefixes. We -// will then for each unique compressed set of value of that variable, call -// a new FullyCompressTuplesRecursive() on the corresponding subset. 
-void FullyCompressTuplesRecursive( - absl::Span domain_sizes, - absl::Span> tuples, - std::vector>* reversed_suffix, - std::vector>>* output) { - struct TempData { - absl::InlinedVector values; - int index; - - bool operator<(const TempData& other) const { - return values < other.values; - } - }; - std::vector temp_data; - - CHECK(!tuples.empty()); - CHECK(!tuples[0].empty()); - const int64_t domain_size = domain_sizes[tuples[0].size() - 1]; - - // Sort tuples and regroup common prefix in temp_data. - std::sort(tuples.begin(), tuples.end()); - for (int i = 0; i < tuples.size();) { - const int start = i; - temp_data.push_back({{tuples[start].back()}, start}); - tuples[start].pop_back(); - for (++i; i < tuples.size(); ++i) { - const int64_t v = tuples[i].back(); - tuples[i].pop_back(); - if (tuples[i] == tuples[start]) { - temp_data.back().values.push_back(v); - } else { - tuples[i].push_back(v); - break; - } - } - - // If one of the value is the special value kTableAnyValue, we convert - // it to the "empty means any" format. - for (const int64_t v : temp_data.back().values) { - if (v == kTableAnyValue) { - temp_data.back().values.clear(); - break; - } - } - gtl::STLSortAndRemoveDuplicates(&temp_data.back().values); - - // If values cover the whole domain, we clear the vector. This allows to - // use less space and avoid creating uneeded clauses. - if (temp_data.back().values.size() == domain_size) { - temp_data.back().values.clear(); - } - } - - if (temp_data.size() == 1) { - output->push_back({}); - for (const int64_t v : tuples[temp_data[0].index]) { - if (v == kTableAnyValue) { - output->back().push_back({}); - } else { - output->back().push_back({v}); - } - } - output->back().push_back(temp_data[0].values); - for (int i = reversed_suffix->size(); --i >= 0;) { - output->back().push_back((*reversed_suffix)[i]); - } - return; - } - - // Sort temp_data and make recursive call for all tuples that share the - // same suffix. 
- std::sort(temp_data.begin(), temp_data.end()); - std::vector> temp_tuples; - for (int i = 0; i < temp_data.size();) { - reversed_suffix->push_back(temp_data[i].values); - const int start = i; - temp_tuples.clear(); - for (; i < temp_data.size(); i++) { - if (temp_data[start].values != temp_data[i].values) break; - temp_tuples.push_back(tuples[temp_data[i].index]); - } - FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(temp_tuples), - reversed_suffix, output); - reversed_suffix->pop_back(); - } -} - -} // namespace - -// TODO(user): We can probably reuse the tuples memory always and never create -// new one. We should also be able to code an iterative version of this. Note -// however that the recursion level is bounded by the number of coluns which -// should be small. -std::vector>> FullyCompressTuples( - absl::Span domain_sizes, - std::vector>* tuples) { - std::vector> reversed_suffix; - std::vector>> output; - FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(*tuples), - &reversed_suffix, &output); - return output; -} - -namespace { - class CliqueDecomposition { public: CliqueDecomposition(const std::vector>& graph, diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 36d9e8d996..fa10cdc84c 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -15,7 +15,6 @@ #define OR_TOOLS_SAT_UTIL_H_ #include -#include #include #include #include @@ -27,7 +26,6 @@ #include "absl/base/macros.h" #include "absl/container/btree_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log_streamer.h" #include "absl/numeric/int128.h" @@ -601,38 +599,6 @@ class Percentile { const int record_limit_; }; -// This method tries to compress a list of tuples by merging complementary -// tuples, that is a set of tuples that only differ on one variable, and that -// cover the domain of the variable. 
In that case, it will keep only one tuple, -// and replace the value for variable by any_value, the equivalent of '*' in -// regexps. -// -// This method is exposed for testing purposes. -constexpr int64_t kTableAnyValue = std::numeric_limits::min(); -void CompressTuples(absl::Span domain_sizes, - std::vector>* tuples); - -// Similar to CompressTuples() but produces a final table where each cell is -// a set of value. This should result in a table that can still be encoded -// efficiently in SAT but with less tuples and thus less extra Booleans. Note -// that if a set of value is empty, it is interpreted at "any" so we can gain -// some space. -// -// The passed tuples vector is used as temporary memory and is detroyed. -// We interpret kTableAnyValue as an "any" tuple. -// -// TODO(user): To reduce memory, we could return some absl::Span in the last -// layer instead of vector. -// -// TODO(user): The final compression is depend on the order of the variables. -// For instance the table (1,1)(1,2)(1,3),(1,4),(2,3) can either be compressed -// as (1,*)(2,3) or (1,{1,2,4})({1,3},3). More experiment are needed to devise -// a better heuristic. It might for example be good to call CompressTuples() -// first. -std::vector>> FullyCompressTuples( - absl::Span domain_sizes, - std::vector>* tuples); - // Keep the top n elements from a stream of elements. // // TODO(user): We could use gtl::TopN when/if it gets open sourced. Note that diff --git a/ortools/sat/util_test.cc b/ortools/sat/util_test.cc new file mode 100644 index 0000000000..f0bd504d06 --- /dev/null +++ b/ortools/sat/util_test.cc @@ -0,0 +1,948 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/util.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/log/check.h" +#include "absl/numeric/int128.h" +#include "absl/random/random.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/logging.h" +#include "ortools/base/mathutil.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/random_engine.h" +#include "ortools/util/sorted_interval_list.h" + +using ::testing::UnorderedElementsAre; + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; + +TEST(CompactVectorVectorTest, EmptyCornerCases) { + CompactVectorVector storage; + EXPECT_EQ(storage.size(), 0); + + const int index = storage.Add(std::vector()); + EXPECT_EQ(storage.size(), 1); + EXPECT_EQ(storage[index].size(), 0); +} + +TEST(CompactVectorVectorTest, ResetFromFlatMapping) { + CompactVectorVector storage; + EXPECT_EQ(storage.size(), 0); + + const std::vector input = {2, 2, 1, 1, 1, 0, 0, 2, 2}; + storage.ResetFromFlatMapping(input, IdentityMap()); + + EXPECT_EQ(storage.size(), 3); + EXPECT_THAT(storage[0], ElementsAre(5, 6)); + EXPECT_THAT(storage[1], ElementsAre(2, 3, 4)); + EXPECT_THAT(storage[2], 
ElementsAre(0, 1, 7, 8)); +} + +TEST(CompactVectorVectorTest, RemoveBySwap) { + CompactVectorVector storage; + EXPECT_EQ(storage.size(), 0); + + const std::vector input = {2, 2, 1, 1, 1, 0, 0, 2, 2}; + storage.ResetFromFlatMapping(input, IdentityMap()); + + EXPECT_EQ(storage.size(), 3); + EXPECT_THAT(storage[0], ElementsAre(5, 6)); + EXPECT_THAT(storage[1], ElementsAre(2, 3, 4)); + EXPECT_THAT(storage[2], ElementsAre(0, 1, 7, 8)); + + storage.RemoveBySwap(1, 1); + EXPECT_THAT(storage[1], ElementsAre(2, 4)); + + storage.RemoveBySwap(2, 1); + EXPECT_THAT(storage[2], ElementsAre(0, 8, 7)); +} + +TEST(CompactVectorVectorTest, ShrinkValues) { + CompactVectorVector storage; + EXPECT_EQ(storage.size(), 0); + + const std::vector input = {2, 2, 1, 1, 1, 0, 0, 2, 2}; + storage.ResetFromFlatMapping(input, IdentityMap()); + + EXPECT_EQ(storage.size(), 3); + EXPECT_THAT(storage[0], ElementsAre(5, 6)); + EXPECT_THAT(storage[1], ElementsAre(2, 3, 4)); + EXPECT_THAT(storage[2], ElementsAre(0, 1, 7, 8)); + + storage.ReplaceValuesBySmallerSet(2, {3, 4, 5}); + EXPECT_THAT(storage[2], ElementsAre(3, 4, 5)); +} + +TEST(CompactVectorVectorTest, ResetFromTranspose) { + CompactVectorVector storage; + EXPECT_EQ(storage.size(), 0); + + const std::vector keys = {2, 2, 1, 1, 1, 0, 0, 2, 2}; + const std::vector values = {3, 4, 0, 0, 1, 5, 1, 2, 2}; + storage.ResetFromFlatMapping(keys, values); + + ASSERT_EQ(storage.size(), 3); + EXPECT_THAT(storage[0], ElementsAre(5, 1)); + EXPECT_THAT(storage[1], ElementsAre(0, 0, 1)); + EXPECT_THAT(storage[2], ElementsAre(3, 4, 2, 2)); + + CompactVectorVector transpose; + transpose.ResetFromTranspose(storage); + + ASSERT_EQ(transpose.size(), 6); + EXPECT_THAT(transpose[0], ElementsAre(1, 1)); + EXPECT_THAT(transpose[1], ElementsAre(0, 1)); + EXPECT_THAT(transpose[2], ElementsAre(2, 2)); + EXPECT_THAT(transpose[3], ElementsAre(2)); + EXPECT_THAT(transpose[4], ElementsAre(2)); + EXPECT_THAT(transpose[5], ElementsAre(0)); + + // Note that retransposing sorts ! 
+ CompactVectorVector second_transpose; + second_transpose.ResetFromTranspose(transpose); + + ASSERT_EQ(second_transpose.size(), 3); + EXPECT_THAT(second_transpose[0], ElementsAre(1, 5)); + EXPECT_THAT(second_transpose[1], ElementsAre(0, 0, 1)); + EXPECT_THAT(second_transpose[2], ElementsAre(2, 2, 3, 4)); +} + +TEST(FormatCounterTest, BasicCases) { + EXPECT_EQ("12", FormatCounter(12)); + EXPECT_EQ("12'345", FormatCounter(12345)); + EXPECT_EQ("123'456'789", FormatCounter(123456789)); +} + +TEST(FormatTable, BasicAlign) { + std::vector> table{ + {"x", "x", "x", "x", "x"}, + {FormatName("xx"), "xx", "xx", "xx", "xx"}, + {FormatName("xxx"), "xxx", "xxx", "xxx", "xxx"}}; + + EXPECT_EQ( + "x x x x x\n" + " 'xx': xx xx xx xx\n" + " 'xxx': xxx xxx xxx xxx\n", + FormatTable(table)); +} + +TEST(ModularInverseTest, AllSmallValues) { + for (int64_t m = 1; m < 1000; ++m) { + for (int64_t x = 1; x < m; ++x) { + const int64_t inverse = ModularInverse(x, m); + ASSERT_GE(inverse, 0); + ASSERT_LT(inverse, m); + if (inverse == 0) { + ASSERT_NE(std::gcd(x, m), 1); + } else { + ASSERT_EQ(x * inverse % m, 1); + } + } + } +} + +TEST(ModularInverseTest, BasicOverflowTest) { + absl::BitGen random; + const int64_t max = std::numeric_limits::max(); + for (int i = 0; i < 100000; ++i) { + const int64_t m = max - absl::LogUniform(random, 0, max); + const int64_t x = absl::Uniform(random, 0, m); + const int64_t inverse = ModularInverse(x, m); + ASSERT_GE(inverse, 0); + ASSERT_LT(inverse, m); + if (inverse == 0) { + ASSERT_NE(std::gcd(x, m), 1); + } else { + absl::int128 test_x = x; + absl::int128 test_inverse = inverse; + absl::int128 test_m = m; + ASSERT_EQ(test_x * test_inverse % test_m, 1); + } + } +} + +TEST(ProductWithodularInverseTest, FewSmallValues) { + const int limit = 50; + for (int64_t mod = 1; mod < limit; ++mod) { + for (int64_t coeff = -limit; coeff < limit; ++coeff) { + if (coeff == 0 || std::gcd(mod, std::abs(coeff)) != 1) continue; + for (int64_t rhs = -mod; rhs < mod; ++rhs) { 
+ const int64_t result = ProductWithModularInverse(coeff, mod, rhs); + for (int64_t test = -limit; test < limit; ++test) { + const int64_t x = test * mod + result; + ASSERT_EQ(PositiveMod(x * coeff, mod), PositiveMod(rhs, mod)); + } + } + } + } +} + +TEST(SolveDiophantineEquationOfSizeTwoTest, FewSmallValues) { + const int limit = 50; + for (int64_t a = -limit; a < limit; ++a) { + if (a == 0) continue; + for (int64_t b = -limit; b < limit; ++b) { + if (b == 0) continue; + for (int64_t c = -limit; c < limit; ++c) { + int64_t ca = a; + int64_t cb = b; + int64_t cc = c; + int64_t x0, y0; + const bool r = SolveDiophantineEquationOfSizeTwo(ca, cb, cc, x0, y0); + if (!r) { + // This is the only case. + const int gcd = std::gcd(std::abs(a), std::abs(b)); + CHECK_GT(gcd, 1); + ASSERT_NE(c % gcd, 0); + continue; + } + ASSERT_EQ(ca * x0 + cb * y0, cc); + } + } + } +} + +TEST(SolveDiophantineEquationOfSizeTwoTest, BasicOverflowTest) { + absl::BitGen random; + const int64_t max = std::numeric_limits::max(); + for (int i = 0; i < 100000; ++i) { + int64_t a = max - absl::LogUniform(random, 0, max); + int64_t b = max - absl::LogUniform(random, 0, max); + int64_t cte = absl::Uniform(random, 0, max); + if (absl::Bernoulli(random, 0.5)) a = -a; + if (absl::Bernoulli(random, 0.5)) b = -b; + if (absl::Bernoulli(random, 0.5)) cte = -cte; + + int64_t x0, y0; + if (!SolveDiophantineEquationOfSizeTwo(a, b, cte, x0, y0)) { + // This is the only case. 
+ const int64_t gcd = std::gcd(std::abs(a), std::abs(b)); + CHECK_GT(gcd, 1); + ASSERT_NE(cte % gcd, 0); + continue; + } + ASSERT_EQ( + absl::int128{a} * absl::int128{x0} + absl::int128{b} * absl::int128{y0}, + absl::int128{cte}); + } +} + +TEST(ClosestMultipleTest, BasicCases) { + EXPECT_EQ(ClosestMultiple(9, 10), 10); + EXPECT_EQ(ClosestMultiple(-9, 10), -10); + EXPECT_EQ(ClosestMultiple(5, 10), 0); + EXPECT_EQ(ClosestMultiple(-5, 10), 0); + EXPECT_EQ(ClosestMultiple(6, 10), 10); + EXPECT_EQ(ClosestMultiple(-6, 10), -10); + EXPECT_EQ(ClosestMultiple(789, 10), 790); +} + +TEST(LinearInequalityCanBeReducedWithClosestMultipleTest, BasicCase) { + std::vector coeffs = {99, 101}; + std::vector lbs = {-10, -10}; + std::vector ubs = {10, 10}; + + // Trivially true case. + int64_t new_rhs; + EXPECT_TRUE(LinearInequalityCanBeReducedWithClosestMultiple( + 100, coeffs, lbs, ubs, 10000000, &new_rhs)); + EXPECT_EQ(new_rhs, 20); + + // X + Y <= 3 case + EXPECT_TRUE(LinearInequalityCanBeReducedWithClosestMultiple( + 100, coeffs, lbs, ubs, 350, &new_rhs)); + EXPECT_EQ(new_rhs, 3); + + // X + Y <= 3, limit case. + // + // It doesn't work with 316 since 10 * 101 - 7 * 99 = 317. + EXPECT_TRUE(LinearInequalityCanBeReducedWithClosestMultiple( + 100, coeffs, lbs, ubs, 317, &new_rhs)); + EXPECT_EQ(new_rhs, 3); + + // False case: we cannot reduce the equation to a multiple of 100. 
+ EXPECT_FALSE(LinearInequalityCanBeReducedWithClosestMultiple( + 100, coeffs, lbs, ubs, 316, &new_rhs)); +} + +TEST(LinearInequalityCanBeReducedWithClosestMultipleTest, Random) { + absl::BitGen random; + int num_reductions = 0; + const int num_tests = 100; + const int num_terms = 5; + const int64_t base = 10000; + for (int test = 0; test < num_tests; ++test) { + // We generate a random expression around 10 k + [-10, 10] + std::vector coeffs; + std::vector lbs(num_terms, -1); + std::vector ubs(num_terms, +1); + int64_t max_activity = 0; + for (int i = 0; i < num_terms; ++i) { + coeffs.push_back(base + absl::Uniform(random, -10, 10)); + max_activity += + std::max(coeffs.back() * ubs.back(), coeffs.back() * lbs.back()); + } + const int64_t target = max_activity - 2 * base; + const int64_t rhs = absl::Uniform(random, target - 50, target + 50); + + int64_t new_rhs; + const bool ok = LinearInequalityCanBeReducedWithClosestMultiple( + base, coeffs, lbs, ubs, rhs, &new_rhs); + if (!ok) continue; + + VLOG(2) << absl::StrJoin(coeffs, ", ") << " <= " << rhs << " new " + << new_rhs; + + // Test that the set of solutions is the same. + for (int number = 0; number < pow(3, num_terms); ++number) { + int temp = number; + int64_t activity = 0; + int64_t new_activity = 0; + for (int i = 0; i < num_terms; ++i) { + const int x = (temp % 3) - 1; + temp /= 3; + activity += coeffs[i] * x; + new_activity += x; + } + if (activity <= rhs) { + ASSERT_LE(new_activity, new_rhs); + } else { + ASSERT_GT(new_activity, new_rhs); + } + } + + ++num_reductions; + } + + // Over 10k runs, this worked. So we simplify sometimes but not always. 
+ VLOG(2) << num_reductions; + EXPECT_GE(num_reductions, 10); + EXPECT_LT(num_reductions, num_tests); +} + +TEST(MoveOneUnprocessedLiteralLastTest, CorrectBehavior) { + absl::btree_set moved_last; + std::vector literals; + for (int i = 0; i < 100; ++i) { + literals.push_back(Literal(BooleanVariable(i), true)); + } + + int i = 0; + while (MoveOneUnprocessedLiteralLast(moved_last, literals.size(), + &literals) != -1) { + ++i; + EXPECT_FALSE(moved_last.contains(literals.back().Index())); + moved_last.insert(literals.back().Index()); + } + EXPECT_EQ(i, literals.size()); + + // No change in the actual literals. + std::sort(literals.begin(), literals.end()); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(literals[i], Literal(BooleanVariable(i), true)); + } +} + +int SumOfPrefixesForSize(int n) { + absl::btree_set moved_last; + std::vector literals; + for (int i = 0; i < n; ++i) { + literals.push_back(Literal(BooleanVariable(i), true)); + } + int s = 0; + int result = 0; + std::vector before, after; + while (true) { + before = literals; + s = MoveOneUnprocessedLiteralLast(moved_last, literals.size(), &literals); + if (s == -1) return result; + + moved_last.insert(literals.back().Index()); + result += n - s; + + // Check that s is an actual prefix size. + after = literals; + before.resize(s); + after.resize(s); + EXPECT_EQ(before, after); + } + return result; +} + +TEST(MoveOneUnprocessedLiteralLastTest, CorrectComplexity) { + EXPECT_EQ(SumOfPrefixesForSize(0), 0); + EXPECT_EQ(SumOfPrefixesForSize(1), 0); + EXPECT_EQ(SumOfPrefixesForSize(2), 2); + EXPECT_EQ(SumOfPrefixesForSize(3), 5); + EXPECT_EQ(SumOfPrefixesForSize(4), 4 * log2(4)); + + // Note that this one can be done in 12 with S(5) = S(2) + 5 + S(3), so our + // algorithm is suboptimal for non power of 2 sizes starting from here. 
+ EXPECT_EQ(SumOfPrefixesForSize(5), 13); + EXPECT_EQ(SumOfPrefixesForSize(6), 16); + EXPECT_EQ(SumOfPrefixesForSize(7), 20); + EXPECT_EQ(SumOfPrefixesForSize(8), 8 * log2(8)); + EXPECT_EQ(SumOfPrefixesForSize(9), 33); + EXPECT_EQ(SumOfPrefixesForSize(10), 36); + EXPECT_EQ(SumOfPrefixesForSize(100), 688); + EXPECT_EQ(SumOfPrefixesForSize(1000), 9984); + EXPECT_EQ(SumOfPrefixesForSize(1024), 1024 * log2(1024)); +} + +constexpr double kTolerance = 1e-6; + +TEST(IncrementalAverage, PositiveData) { + IncrementalAverage positive_data; + for (int i = 1; i < 101; ++i) { + positive_data.AddData(i); + EXPECT_EQ(i, positive_data.NumRecords()); + EXPECT_NEAR((i + 1) / 2.0, positive_data.CurrentAverage(), kTolerance); + } +} + +TEST(IncrementalAverage, NegativeData) { + IncrementalAverage negative_data; + for (int i = 1; i < 101; ++i) { + negative_data.AddData(-i); + EXPECT_EQ(i, negative_data.NumRecords()); + EXPECT_NEAR(-(i + 1) / 2.0, negative_data.CurrentAverage(), kTolerance); + } +} + +TEST(IncrementalAverage, MixedData) { + IncrementalAverage data; + EXPECT_EQ(0, data.NumRecords()); + EXPECT_EQ(0.0, data.CurrentAverage()); + + data.AddData(-1); + EXPECT_EQ(1, data.NumRecords()); + EXPECT_EQ(-1.0, data.CurrentAverage()); + + data.AddData(0); + EXPECT_EQ(2, data.NumRecords()); + EXPECT_NEAR(-1.0 / 2.0, data.CurrentAverage(), kTolerance); + + data.AddData(1); + EXPECT_EQ(3, data.NumRecords()); + EXPECT_NEAR(0.0, data.CurrentAverage(), kTolerance); +} + +TEST(IncrementalAverage, InitialFeed) { + IncrementalAverage data(5.0); + EXPECT_EQ(0, data.NumRecords()); + EXPECT_EQ(5.0, data.CurrentAverage()); +} + +TEST(IncrementalAverage, Reset) { + IncrementalAverage data; + data.AddData(5.0); + EXPECT_EQ(1, data.NumRecords()); + EXPECT_EQ(5.0, data.CurrentAverage()); + + data.Reset(3.0); + EXPECT_EQ(0, data.NumRecords()); + EXPECT_EQ(3.0, data.CurrentAverage()); +} + +TEST(ExponentialMovingAverage, Average) { + ExponentialMovingAverage data(/*decaying_factor=*/0.1); + EXPECT_EQ(0, 
data.NumRecords()); + EXPECT_EQ(0.0, data.CurrentAverage()); + + data.AddData(10.0); + EXPECT_EQ(1, data.NumRecords()); + EXPECT_EQ(10.0, data.CurrentAverage()); + + data.AddData(20.0); + EXPECT_EQ(2, data.NumRecords()); + EXPECT_NEAR(19.0, data.CurrentAverage(), kTolerance); + + data.AddData(30); + EXPECT_EQ(3, data.NumRecords()); + EXPECT_NEAR(28.9, data.CurrentAverage(), kTolerance); +} + +TEST(Percentile, BasicTest) { + // Example at https://en.wikipedia.org/wiki/Percentile + Percentile data(/*record_limit=*/5); + EXPECT_EQ(0, data.NumRecords()); + + data.AddRecord(15.0); + data.AddRecord(20.0); + data.AddRecord(35.0); + data.AddRecord(40.0); + data.AddRecord(50.0); + EXPECT_EQ(5, data.NumRecords()); + + EXPECT_NEAR(15.0, data.GetPercentile(5), kTolerance); + EXPECT_NEAR(20.0, data.GetPercentile(30), kTolerance); + EXPECT_NEAR(27.5, data.GetPercentile(40), kTolerance); + EXPECT_NEAR(50, data.GetPercentile(95), kTolerance); +} + +TEST(Percentile, RecordLimit) { + Percentile data(/*record_limit=*/2); + EXPECT_EQ(0, data.NumRecords()); + + data.AddRecord(15.0); + EXPECT_EQ(1, data.NumRecords()); + data.AddRecord(20.0); + EXPECT_EQ(2, data.NumRecords()); + EXPECT_NEAR(15.0, data.GetPercentile(10), kTolerance); + EXPECT_NEAR(20.0, data.GetPercentile(90), kTolerance); + data.AddRecord(35.0); + data.AddRecord(40.0); + EXPECT_EQ(2, data.NumRecords()); + EXPECT_NEAR(35.0, data.GetPercentile(10), kTolerance); + EXPECT_NEAR(40.0, data.GetPercentile(90), kTolerance); +} + +TEST(Percentile, BasicTest2) { + Percentile data(/*record_limit=*/10); + EXPECT_EQ(0, data.NumRecords()); + + data.AddRecord(6.0753); + data.AddRecord(8.6678); + data.AddRecord(0.4823); + data.AddRecord(6.7243); + data.AddRecord(5.6375); + data.AddRecord(2.3846); + data.AddRecord(4.1328); + data.AddRecord(5.6852); + data.AddRecord(12.1568); + data.AddRecord(10.5389); + EXPECT_EQ(10, data.NumRecords()); + + EXPECT_NEAR(5.6709, data.GetPercentile(42), 1e-5); +} + +TEST(Percentile, RandomNumbers) { + const 
int record_limit = 1000; + Percentile data(record_limit); + EXPECT_EQ(0, data.NumRecords()); + + std::vector records; + absl::BitGen random; + for (int i = 0; i < record_limit; ++i) { + double record = absl::Uniform(random, -10000, 10000); + records.push_back(record); + data.AddRecord(record); + } + EXPECT_EQ(record_limit, data.NumRecords()); + std::sort(records.begin(), records.end()); + for (int i = 0; i < record_limit; ++i) { + EXPECT_NEAR(records[i], + data.GetPercentile((i + 0.5) * 100.0 / record_limit), + kTolerance); + } +} + +TEST(SafeDoubleToInt64Test, BasicCases) { + const double kInfinity = std::numeric_limits::infinity(); + const int64_t kMax = std::numeric_limits::max(); + const int64_t kMin = std::numeric_limits::min(); + const int64_t max53 = (int64_t{1} << 53) - 1; + + // Arbitrary behavior for nans. + EXPECT_EQ(SafeDoubleToInt64(std::numeric_limits::quiet_NaN()), 0); + EXPECT_EQ(SafeDoubleToInt64(std::numeric_limits::signaling_NaN()), 0); + + EXPECT_EQ(SafeDoubleToInt64(static_cast(kMax)), kMax); + EXPECT_EQ(SafeDoubleToInt64(kInfinity), kMax); + + // Transition for max. + for (int i = 0; i < 512; ++i) { + ASSERT_EQ(SafeDoubleToInt64(static_cast(kMax - i)), kMax); + } + for (int i = 512; i < 1024 + 511; ++i) { + ASSERT_EQ(SafeDoubleToInt64(static_cast(kMax - i)), kMax - 1023); + } + ASSERT_EQ(SafeDoubleToInt64(static_cast(kMax - 1024 - 511)), + kMax - 2047); + + // Transition for max precision. + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(SafeDoubleToInt64(static_cast(max53 - i)), max53 - i); + } + int num_error = 0; + for (int i = 0; i < 10; ++i) { + if (SafeDoubleToInt64(static_cast(max53 + i)) != max53 + i) { + ++num_error; + } + } + EXPECT_EQ(num_error, 4); + + // static_cast just truncate the number... 
+ EXPECT_EQ(SafeDoubleToInt64(0.1), 0); + EXPECT_EQ(SafeDoubleToInt64(0.9), 0); + + EXPECT_EQ(SafeDoubleToInt64(static_cast(kMin)), kMin); + EXPECT_EQ(SafeDoubleToInt64(-kInfinity), kMin); +} + +TEST(MaxBoundedSubsetSumTest, LowMaxValue) { + const int bound = 49; + MaxBoundedSubsetSum bounded_subset_sum(bound); + for (int gcd = 2; gcd < 10; ++gcd) { + bounded_subset_sum.Reset(bound); + for (int i = 0; i < 1000; ++i) { + bounded_subset_sum.Add((i % 50) * gcd); + } + EXPECT_EQ(bounded_subset_sum.CurrentMax(), bound / gcd * gcd); + } +} + +TEST(MaxBoundedSubsetSumTest, LowNumberOfElement) { + MaxBoundedSubsetSum bounded_subset_sum(178'979); + bounded_subset_sum.Add(150'000); + bounded_subset_sum.Add(28'000); + bounded_subset_sum.Add(1000); + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 178'000); + + // Too many elements causes an "abort" and we just return the bound. + for (int i = 0; i < 10; ++i) { + bounded_subset_sum.Add(i); + } + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 178'979); +} + +TEST(MaxBoundedSubsetSumTest, FailBackToGcd) { + MaxBoundedSubsetSum bounded_subset_sum(/*bound=*/10122); + bounded_subset_sum.AddMultiples(100, 10); + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 1000); + + bounded_subset_sum.AddMultiples(200, 1000); + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 10100); + + // We could have better bounding maybe in this case. 
+ bounded_subset_sum.Add(1); + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 10122); +} + +TEST(MaxBoundedSubsetSumTest, SimpleMultiChoice) { + MaxBoundedSubsetSum bounded_subset_sum(34); + bounded_subset_sum.AddChoices({3, 10, 19}); + bounded_subset_sum.AddChoices({0, 2, 4, 40}); + bounded_subset_sum.AddChoices({3, 7, 8, 16}); + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 31); +} + +TEST(MaxBoundedSubsetSumTest, CheckMaxIfAdded) { + MaxBoundedSubsetSum bounded_subset_sum(34); + bounded_subset_sum.Add(10); + bounded_subset_sum.Add(10); + bounded_subset_sum.Add(10); + EXPECT_EQ(bounded_subset_sum.MaxIfAdded(12), 32); + EXPECT_EQ(bounded_subset_sum.MaxIfAdded(15), 30); + EXPECT_EQ(bounded_subset_sum.MaxIfAdded(34), 34); + for (int i = 0; i < 100; ++i) { + bounded_subset_sum.Add(18); + } + EXPECT_EQ(bounded_subset_sum.CurrentMax(), 30); + EXPECT_EQ(bounded_subset_sum.MaxIfAdded(5), 33); +} + +static void BM_bounded_subset_sum(benchmark::State& state) { + random_engine_t random_; + const int num_items = state.range(0); + const int num_choices = state.range(1); + const int max_capacity = state.range(2); + const int max_size = state.range(3); + + const int num_updates = num_items * num_choices; + const int capacity = std::uniform_int_distribution( + max_capacity / 2, max_capacity)(random_); + MaxBoundedSubsetSum subset_sum(capacity); + std::uniform_int_distribution size_dist(0, max_size); + std::vector choices(num_choices); + for (auto _ : state) { + subset_sum.Reset(capacity); + for (int i = 0; i < num_items; ++i) { + choices.clear(); + for (int j = 0; j < num_choices; ++j) { + choices[j] = size_dist(random_); + } + subset_sum.AddChoices(choices); + } + } + // Number of updates. 
+ state.SetBytesProcessed(static_cast(state.iterations()) * + num_updates); +} + +BENCHMARK(BM_bounded_subset_sum) + ->Args({10, 3, 30, 5}) + ->Args({10, 4, 50, 10}) + ->Args({10, 4, 30, 20}) + ->Args({25, 3, 30, 5}) + ->Args({25, 4, 50, 10}) + ->Args({25, 4, 30, 20}) + ->Args({60, 3, 30, 5}) + ->Args({60, 4, 50, 10}) + ->Args({60, 4, 30, 20}) + ->Args({100, 3, 30, 5}) + ->Args({100, 4, 50, 10}) + ->Args({100, 4, 30, 20}); + +TEST(FirstFewValuesTest, Basic) { + FirstFewValues<8> values; + EXPECT_EQ(values.LastValue(), std::numeric_limits::max()); + values.Add(3); + EXPECT_THAT(values.reachable(), ElementsAre(0, 3, 6, 9, 12, 15, 18, 21)); + values.Add(5); + EXPECT_THAT(values.reachable(), ElementsAre(0, 3, 5, 6, 8, 9, 10, 11)); + + EXPECT_TRUE(values.MightBeReachable(0)); + EXPECT_TRUE(values.MightBeReachable(100)); + EXPECT_FALSE(values.MightBeReachable(2)); + EXPECT_FALSE(values.MightBeReachable(7)); +} + +TEST(FirstFewValuesTest, Overflow) { + FirstFewValues<6> values; + + const int64_t max = std::numeric_limits::max(); + const int64_t v = max / 3; + values.Add(v); + EXPECT_THAT(values.reachable(), ElementsAre(0, v, 2 * v, 3 * v, max, max)); +} + +TEST(BasicKnapsackSolverTest, BasicFeasibleExample) { + std::vector domains = {Domain(-3, 14), Domain(1, 15)}; + std::vector coeffs = {7, 13}; + std::vector costs = {5, 8}; + Domain rhs(100, 200); + + BasicKnapsackSolver solver; + const auto& result = solver.Solve(domains, coeffs, costs, rhs); + EXPECT_TRUE(result.solved); + EXPECT_FALSE(result.infeasible); + EXPECT_THAT(result.solution, ElementsAre(-2, 9)); +} + +TEST(BasicKnapsackSolverTest, BasicInfesibleExample) { + std::vector domains = {Domain(-3, 8), Domain(1, 8)}; + std::vector coeffs = {7, 13}; + std::vector costs = {5, 8}; + Domain rhs(103); + + BasicKnapsackSolver solver; + const auto& result = solver.Solve(domains, coeffs, costs, rhs); + EXPECT_TRUE(result.solved); + EXPECT_TRUE(result.infeasible); +} + +TEST(BasicKnapsackSolverTest, 
RandomComparisonWithSolver) { + const int num_vars = 6; + absl::BitGen random; + + BasicKnapsackSolver solver; + for (int num_tests = 0; num_tests < 100; ++num_tests) { + std::vector domains; + std::vector coeffs; + std::vector costs; + for (int i = 0; i < num_vars; ++i) { + int a = absl::Uniform(random, -10, 10); + int b = absl::Uniform(random, -10, 10); + if (a > b) std::swap(a, b); + + domains.push_back(Domain(a, b)); + costs.push_back(absl::Uniform(random, -10, 10)); + coeffs.push_back(absl::Uniform(random, -10, 9)); + if (coeffs.back() >= 0) coeffs.back()++; + } + const int c = absl::Uniform(random, num_vars * 5, num_vars * 10); + Domain rhs = Domain(c, c + absl::Uniform(random, 0, 5)); + + // Create corresponding proto. + CpModelProto proto; + auto* linear = proto.add_constraints()->mutable_linear(); + auto* objective = proto.mutable_objective(); + for (int i = 0; i < num_vars; ++i) { + auto* var = proto.add_variables(); + FillDomainInProto(domains[i], var); + linear->add_vars(i); + linear->add_coeffs(coeffs[i]); + objective->add_vars(i); + objective->add_coeffs(costs[i]); + } + FillDomainInProto(rhs, linear); + + const auto& result = solver.Solve(domains, coeffs, costs, rhs); + CHECK(result.solved); // We should always be able to solve here. 
+ + SatParameters params; + params.set_cp_model_presolve(false); + const CpSolverResponse response = SolveWithParameters(proto, params); + + if (result.infeasible) { + EXPECT_EQ(response.status(), INFEASIBLE); + } else { + EXPECT_EQ(response.status(), OPTIMAL); + int64_t objective = 0; + for (int i = 0; i < num_vars; ++i) { + EXPECT_TRUE(domains[i].Contains(result.solution[i])) + << domains[i] << " " << result.solution[i] << " " << coeffs[i]; + objective += costs[i] * result.solution[i]; + } + EXPECT_DOUBLE_EQ(objective, response.objective_value()); + } + } +} + +using CeilFloorTest = testing::TestWithParam>; + +TEST_P(CeilFloorTest, FloorOfRatioInt) { + const int a = std::get<0>(GetParam()); + const int b = std::get<1>(GetParam()); + if (b == 0) return; + int q = MathUtil::FloorOfRatio(a, b); + // a / b - 1 < q <= a / b + EXPECT_LT(static_cast(a) / static_cast(b) - 1.0, + static_cast(q)); + EXPECT_LE(static_cast(q), + static_cast(a) / static_cast(b)); +} + +TEST_P(CeilFloorTest, FloorOfRatioInt128) { + const absl::int128 a = std::get<0>(GetParam()); + const absl::int128 b = std::get<1>(GetParam()); + if (b == 0) return; + absl::int128 q = MathUtil::FloorOfRatio(a, b); + // a / b - 1 < q <= a / b + EXPECT_LT(static_cast(a) / static_cast(b) - 1.0, + static_cast(q)); + EXPECT_LE(static_cast(q), + static_cast(a) / static_cast(b)); +} + +TEST_P(CeilFloorTest, CeilOfRatioInt) { + const int a = std::get<0>(GetParam()); + const int b = std::get<1>(GetParam()); + if (b == 0) return; + int q = MathUtil::CeilOfRatio(a, b); + // a / b <= q < a / b + 1 + EXPECT_LE(static_cast(a) / static_cast(b), + static_cast(q)); + EXPECT_LE(static_cast(q), + static_cast(a) / static_cast(b) + 1.0); +} + +TEST_P(CeilFloorTest, CeilOfRatioInt128) { + const absl::int128 a = std::get<0>(GetParam()); + const absl::int128 b = std::get<1>(GetParam()); + if (b == 0) return; + absl::int128 q = MathUtil::CeilOfRatio(a, b); + // a / b <= q < a / b + 1 + EXPECT_LE(static_cast(a) / static_cast(b), + 
static_cast(q)); + EXPECT_LE(static_cast(q), + static_cast(a) / static_cast(b) + 1.0); +} + +INSTANTIATE_TEST_SUITE_P(CeilFloorTests, CeilFloorTest, + testing::Combine(testing::Range(-10, 10), + testing::Range(-10, 10))); + +TEST(TopNTest, BasicBehavior) { + TopN top3(3); + top3.Add(1, 1.0); + top3.Add(2, 2.0); + EXPECT_THAT(top3.UnorderedElements(), ElementsAre(1, 2)); + top3.Add(3, 2.0); + EXPECT_THAT(top3.UnorderedElements(), ElementsAre(1, 2, 3)); + top3.Add(4, 7); + EXPECT_THAT(top3.UnorderedElements(), ElementsAre(4, 2, 3)); +} + +TEST(TopNTest, Random) { + TopN topN(4); + std::vector input; + for (int i = 0; i < 1000; ++i) input.push_back(i); + std::shuffle(input.begin(), input.end(), absl::BitGen()); + for (const int value : input) topN.Add(value, value); + EXPECT_THAT(topN.UnorderedElements(), + UnorderedElementsAre(999, 998, 997, 996)); +} + +TEST(AtMostOneDecompositionTest, DetectFullClique) { + std::vector> graph{ + {1, 2, 3}, {0, 2, 3}, {0, 1, 3}, {0, 1, 2}}; + std::vector buffer; + absl::BitGen random; + const auto decompo = AtMostOneDecomposition(graph, random, &buffer); + EXPECT_THAT(decompo, ElementsAre(UnorderedElementsAre(0, 1, 2, 3))); +} + +TEST(AtMostOneDecompositionTest, DetectDisjointCliques) { + std::vector> graph{ + {1, 2, 3}, {0, 2, 3}, {0, 1, 3}, {0, 1, 2}, {5, 6}, {4, 6}, {4, 5}}; + std::vector buffer; + absl::BitGen random; + const auto decompo = AtMostOneDecomposition(graph, random, &buffer); + EXPECT_THAT(decompo, UnorderedElementsAre(UnorderedElementsAre(0, 1, 2, 3), + UnorderedElementsAre(4, 5, 6))); +} + +TEST(WeightedPickTest, SizeOne) { + std::vector weights = {123.4}; + absl::BitGen random; + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(WeightedPick(weights, random), 0); + } +} + +TEST(WeightedPickTest, SimpleTest) { + std::vector weights = {1.0, 2.0, 3.0}; + absl::BitGen random; + const int kSample = 1e6; + std::vector counts(3, 0); + for (int i = 0; i < kSample; ++i) { + counts[WeightedPick(weights, random)]++; + } + for (int i 
= 0; i < weights.size(); ++i) { + EXPECT_LE( + std::abs(weights[i] / 6.0 - static_cast(counts[i]) * 1e-6), + 1e-2); + } +} + +} // namespace +} // namespace sat +} // namespace operations_research