Skip to content

Commit

Permalink
Repeating life cycles (#168)
Browse files Browse the repository at this point in the history
Summary:
- The current syntax only allows the generation of a succession of method calls.
- We have changed the design slightly, we now want the user to provide a control flow graph structure.

Pull Request resolved: #168

Reviewed By: anwesht

Differential Revision: D61534984

Pulled By: arthaud

fbshipit-source-id: 510be1605be76effdd7a3386cf6a11cdcf473cdb
  • Loading branch information
ZeyadTarekk authored and facebook-github-bot committed Aug 21, 2024
1 parent d7e1c9d commit c20adf3
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 58 deletions.
150 changes: 115 additions & 35 deletions source/LifecycleMethod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,56 @@

namespace marianatrench {

bool LifecycleGraphNode::operator==(const LifecycleGraphNode& other) const {
return method_calls_ == other.method_calls_ &&
successors_ == other.successors_;
}

void LifeCycleMethodGraph::add_node(
const std::string& node_name,
std::vector<LifecycleMethodCall> method_calls,
std::vector<std::string> successors) {
nodes_.emplace(
node_name,
LifecycleGraphNode(std::move(method_calls), std::move(successors)));
}

bool LifeCycleMethodGraph::operator==(const LifeCycleMethodGraph& other) const {
return nodes_ == other.nodes_;
}

const LifecycleGraphNode* MT_NULLABLE
LifeCycleMethodGraph::get_node(const std::string& node_name) const {
auto it = nodes_.find(node_name);
if (it != nodes_.end()) {
return &it->second;
}
return nullptr;
}

LifeCycleMethodGraph LifeCycleMethodGraph::from_json(const Json::Value& value) {
LifeCycleMethodGraph graph;

for (const auto& node_name : value.getMemberNames()) {
const auto& node = value[node_name];

std::vector<LifecycleMethodCall> method_calls;
for (const auto& instruction :
JsonValidation::null_or_array(node, "instructions")) {
method_calls.push_back(LifecycleMethodCall::from_json(instruction));
}

std::vector<std::string> successors;
for (const auto& successor :
JsonValidation::null_or_array(node, "successors")) {
successors.push_back(JsonValidation::string(successor));
}

graph.add_node(node_name, std::move(method_calls), std::move(successors));
}
return graph;
}

LifecycleMethodCall LifecycleMethodCall::from_json(const Json::Value& value) {
auto method_name = JsonValidation::string(value, "method_name");
auto return_type = JsonValidation::string(value, "return_type");
Expand Down Expand Up @@ -142,14 +192,26 @@ bool LifecycleMethodCall::operator==(const LifecycleMethodCall& other) const {
}

LifecycleMethod LifecycleMethod::from_json(const Json::Value& value) {
auto base_class_name = JsonValidation::string(value, "base_class_name");
auto method_name = JsonValidation::string(value, "method_name");
std::vector<LifecycleMethodCall> callees;
for (const auto& callee : JsonValidation::nonempty_array(value, "callees")) {
callees.emplace_back(LifecycleMethodCall::from_json(callee));
std::string base_class_name =
JsonValidation::string(value, "base_class_name");
std::string method_name = JsonValidation::string(value, "method_name");
if (JsonValidation::has_field(value, "callees")) {
std::vector<LifecycleMethodCall> callees;
for (const auto& callee :
JsonValidation::nonempty_array(value, "callees")) {
callees.push_back(LifecycleMethodCall::from_json(callee));
}
return LifecycleMethod(base_class_name, method_name, std::move(callees));
} else if (JsonValidation::has_field(value, "control_flow_graph")) {
JsonValidation::validate_object(value, "control_flow_graph");
LifeCycleMethodGraph graph = LifeCycleMethodGraph::from_json(
JsonValidation::object(value, "control_flow_graph"));
return LifecycleMethod(base_class_name, method_name, std::move(graph));
}

return LifecycleMethod(base_class_name, method_name, callees);
throw JsonValidationError(
value,
/* field */ std::nullopt,
"key `callees` or `control_flow_graph`");
}

bool LifecycleMethod::validate(
Expand Down Expand Up @@ -178,8 +240,15 @@ bool LifecycleMethod::validate(
return false;
}

for (const auto& callee : callees_) {
callee.validate(base_class, class_hierarchies);
if (const auto* callees =
std::get_if<std::vector<LifecycleMethodCall>>(&body_)) {
for (const auto& callee : *callees) {
callee.validate(base_class, class_hierarchies);
}
} else {
// TODO:handle graph
const auto& graph = std::get<LifeCycleMethodGraph>(body_);
static_cast<void>(graph); // hide unused variable warning.
}

return true;
Expand All @@ -202,15 +271,20 @@ void LifecycleMethod::create_methods(
// in the DexMethod's code. The register location will be used to create the
// invoke operation for methods that take a given DexType* as its argument.
TypeIndexMap type_index_map;
for (const auto& callee : callees_) {
const auto* type_list = callee.get_argument_types();
if (type_list == nullptr) {
ERROR(1, "Callee `{}` has invalid argument types.", callee.to_string());
continue;
}
for (auto* type : *type_list) {
type_index_map.emplace(type, type_index_map.size() + 1);
if (const auto* callees =
std::get_if<std::vector<LifecycleMethodCall>>(&body_)) {
for (const auto& callee : *callees) {
const auto* type_list = callee.get_argument_types();
if (type_list == nullptr) {
ERROR(1, "Callee `{}` has invalid argument types.", callee.to_string());
continue;
}
for (auto* type : *type_list) {
type_index_map.emplace(type, type_index_map.size() + 1);
}
}
} else {
// TODO: Handle graph
}

auto* base_class_type = DexType::get_type(base_class_name_);
Expand Down Expand Up @@ -298,7 +372,7 @@ std::vector<const Method*> LifecycleMethod::get_methods_for_type(

bool LifecycleMethod::operator==(const LifecycleMethod& other) const {
return base_class_name_ == other.base_class_name_ &&
method_name_ == other.method_name_ && callees_ == other.callees_;
method_name_ == other.method_name_ && body_ == other.body_;
}

const DexMethod* MT_NULLABLE LifecycleMethod::create_dex_method(
Expand All @@ -319,26 +393,32 @@ const DexMethod* MT_NULLABLE LifecycleMethod::create_dex_method(
mt_assert(dex_klass != nullptr);

int callee_count = 0;
for (const auto& callee : callees_) {
auto* dex_method = callee.get_dex_method(dex_klass);
if (!dex_method) {
// Dex method does not apply for current APK.
// See `LifecycleMethod::validate()`.
continue;
}

++callee_count;
if (const auto* callees =
std::get_if<std::vector<LifecycleMethodCall>>(&body_)) {
for (const auto& callee : *callees) {
auto* dex_method = callee.get_dex_method(dex_klass);
if (!dex_method) {
// Dex method does not apply for current APK.
// See `LifecycleMethod::validate()`.
continue;
}

++callee_count;

std::vector<Location> invoke_with_registers{this_location};
auto* type_list = callee.get_argument_types();
// This should have been verified at the start of `create_methods`
mt_assert(type_list != nullptr);
for (auto* type : *type_list) {
auto argument_register = method.get_local(type_index_map.at(type));
invoke_with_registers.push_back(argument_register);
std::vector<Location> invoke_with_registers{this_location};
auto* type_list = callee.get_argument_types();
// This should have been verified at the start of `create_methods`
mt_assert(type_list != nullptr);
for (auto* type : *type_list) {
auto argument_register = method.get_local(type_index_map.at(type));
invoke_with_registers.push_back(argument_register);
}
main_block->invoke(
IROpcode::OPCODE_INVOKE_VIRTUAL, dex_method, invoke_with_registers);
}
main_block->invoke(
IROpcode::OPCODE_INVOKE_VIRTUAL, dex_method, invoke_with_registers);
} else {
// TODO: Handle graph
}

if (callee_count < 2) {
Expand Down
56 changes: 53 additions & 3 deletions source/LifecycleMethod.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#pragma once

#include <variant>

#include <fmt/format.h>
#include <json/json.h>

Expand Down Expand Up @@ -53,6 +55,10 @@ class LifecycleMethodCall {

static LifecycleMethodCall from_json(const Json::Value& value);

const std::string& get_method_name() const {
return method_name_;
}

void validate(
const DexClass* base_class,
const ClassHierarchies& class_hierarchies) const;
Expand Down Expand Up @@ -93,6 +99,49 @@ class LifecycleMethodCall {
std::optional<std::string> defined_in_derived_class_;
};

class LifecycleGraphNode {
public:
LifecycleGraphNode(
std::vector<LifecycleMethodCall> method_calls,
std::vector<std::string> successors)
: method_calls_(std::move(method_calls)),
successors_(std::move(successors)) {}

INCLUDE_DEFAULT_COPY_CONSTRUCTORS_AND_ASSIGNMENTS(LifecycleGraphNode)

const std::vector<LifecycleMethodCall>& method_calls() const {
return method_calls_;
}
const std::vector<std::string>& successors() const {
return successors_;
}
bool operator==(const LifecycleGraphNode& other) const;

private:
std::vector<LifecycleMethodCall> method_calls_;
std::vector<std::string> successors_;
};

class LifeCycleMethodGraph {
public:
LifeCycleMethodGraph() {}

INCLUDE_DEFAULT_COPY_CONSTRUCTORS_AND_ASSIGNMENTS(LifeCycleMethodGraph)

void add_node(
const std::string& node_name,
std::vector<LifecycleMethodCall> method_calls,
std::vector<std::string> successors);

const LifecycleGraphNode* get_node(const std::string& node_name) const;
bool operator==(const LifeCycleMethodGraph& other) const;

static LifeCycleMethodGraph from_json(const Json::Value& value);

private:
std::unordered_map<std::string, LifecycleGraphNode> nodes_;
};

/**
* A life-cycle method represents a collection of artificial DexMethods that
* simulate the life-cycle of a class.
Expand Down Expand Up @@ -126,10 +175,11 @@ class LifecycleMethod {
explicit LifecycleMethod(
std::string base_class_name,
std::string method_name,
std::vector<LifecycleMethodCall> callees)
std::variant<std::vector<LifecycleMethodCall>, LifeCycleMethodGraph>
callees)
: base_class_name_(std::move(base_class_name)),
method_name_(std::move(method_name)),
callees_(std::move(callees)) {}
body_(std::move(callees)) {}

INCLUDE_DEFAULT_COPY_CONSTRUCTORS_AND_ASSIGNMENTS(LifecycleMethod)

Expand Down Expand Up @@ -174,7 +224,7 @@ class LifecycleMethod {

std::string base_class_name_;
std::string method_name_;
std::vector<LifecycleMethodCall> callees_;
std::variant<std::vector<LifecycleMethodCall>, LifeCycleMethodGraph> body_;
ConcurrentMap<const DexType*, const Method*> class_to_lifecycle_method_;
};

Expand Down
Loading

0 comments on commit c20adf3

Please sign in to comment.