Skip to content

Commit

Permalink
Rename CalleeInterval
Browse files Browse the repository at this point in the history
Summary: Rename `CalleeInterval` to `CallClassIntervalContext`, within which we have the fields `callee_interval` and `preserves_type_context`, the latter actually applies to how the caller calls the callee.

Reviewed By: anwesht

Differential Revision: D48141877

fbshipit-source-id: b8566e8d362680db3c9d0c88b882b51ab04bdc4a
  • Loading branch information
Yuh Shin Ong authored and facebook-github-bot committed Aug 14, 2023
1 parent b554e07 commit cc5057b
Show file tree
Hide file tree
Showing 34 changed files with 350 additions and 330 deletions.
2 changes: 1 addition & 1 deletion source/ArtificialMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::vector<Model> ArtificialMethods::models(Context& context) const {
/* call_info */ CallInfo::declaration(),
/* field_callee */ nullptr,
/* call_position */ nullptr,
/* callee_interval */ CalleeInterval(),
/* class_interval_context */ CallClassIntervalContext(),
/* distance */ 0,
/* origins */ {},
/* field origins */ {},
Expand Down
2 changes: 1 addition & 1 deletion source/BackwardTaintTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ void check_flows_to_array_allocation(
/* call_info */ CallInfo::origin(),
/* field_callee */ nullptr,
/* call_position */ position,
/* callee_interval */ CalleeInterval(),
/* class_interval_context */ CallClassIntervalContext(),
/* distance */ 1,
/* origins */ MethodSet{array_allocation_method},
/* field_origins */ {},
Expand Down
43 changes: 43 additions & 0 deletions source/CallClassIntervalContext.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ostream>

#include <mariana-trench/CallClassIntervalContext.h>
#include <mariana-trench/Frame.h>
#include <mariana-trench/TaintConfig.h>

namespace marianatrench {

CallClassIntervalContext::CallClassIntervalContext(
ClassIntervals::Interval interval,
bool preserves_type_context)
: callee_interval_(std::move(interval)),
preserves_type_context_(preserves_type_context) {}

CallClassIntervalContext::CallClassIntervalContext(const TaintConfig& config)
: CallClassIntervalContext(config.class_interval_context()) {}

CallClassIntervalContext::CallClassIntervalContext(const Frame& frame)
: CallClassIntervalContext(frame.class_interval_context()) {}

Json::Value CallClassIntervalContext::to_json() const {
auto value = Json::Value(Json::objectValue);
value["callee_interval"] = ClassIntervals::interval_to_json(callee_interval_);
value["preserves_type_context"] = Json::Value(preserves_type_context_);
return value;
}

std::ostream& operator<<(
std::ostream& out,
const CallClassIntervalContext& class_interval_context) {
return out << "{" << class_interval_context.callee_interval()
<< ", preserves_type_context="
<< class_interval_context.preserves_type_context() << "}";
}

} // namespace marianatrench
41 changes: 21 additions & 20 deletions source/CalleeInterval.h → source/CallClassIntervalContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,43 @@ class Frame;
/**
* Represents the class interval of a callee in `Taint`.
*
* interval_:
* callee_interval_:
* Represents the class interval of the method based on the
* receiver's type.
* preserves_type_context_:
* True iff the callee was called with `this.` (i.e. the method call's
* receiver has the same type as the caller's class).
*/
class CalleeInterval {
class CallClassIntervalContext {
public:
CalleeInterval()
: interval_(ClassIntervals::Interval::top()),
CallClassIntervalContext()
: callee_interval_(ClassIntervals::Interval::top()),
preserves_type_context_(false) {
// Default constructor is expected to produce a "default" interval.
mt_assert(is_default());
}

explicit CalleeInterval(
explicit CallClassIntervalContext(
ClassIntervals::Interval interval,
bool preserves_type_context);

explicit CalleeInterval(const TaintConfig& config);
explicit CallClassIntervalContext(const TaintConfig& config);

explicit CalleeInterval(const Frame& frame);
explicit CallClassIntervalContext(const Frame& frame);

INCLUDE_DEFAULT_COPY_CONSTRUCTORS_AND_ASSIGNMENTS(CalleeInterval)
INCLUDE_DEFAULT_COPY_CONSTRUCTORS_AND_ASSIGNMENTS(CallClassIntervalContext)

bool operator==(const CalleeInterval& other) const {
return interval_ == other.interval_ &&
bool operator==(const CallClassIntervalContext& other) const {
return callee_interval_ == other.callee_interval_ &&
preserves_type_context_ == other.preserves_type_context_;
}

bool is_default() const {
return interval_.is_top() && !preserves_type_context_;
return callee_interval_.is_top() && !preserves_type_context_;
}

const ClassIntervals::Interval& interval() const {
return interval_;
const ClassIntervals::Interval& callee_interval() const {
return callee_interval_;
}

bool preserves_type_context() const {
Expand All @@ -66,26 +66,27 @@ class CalleeInterval {

friend std::ostream& operator<<(
std::ostream& out,
const CalleeInterval& interval);
const CallClassIntervalContext& interval);

private:
ClassIntervals::Interval interval_;
ClassIntervals::Interval callee_interval_;
bool preserves_type_context_;
};

} // namespace marianatrench

template <>
struct std::hash<marianatrench::CalleeInterval> {
std::size_t operator()(
const marianatrench::CalleeInterval& callee_interval) const {
struct std::hash<marianatrench::CallClassIntervalContext> {
std::size_t operator()(const marianatrench::CallClassIntervalContext&
class_interval_context) const {
std::size_t seed = 0;
boost::hash_combine(
seed,
std::hash<marianatrench::ClassIntervals::Interval>()(
callee_interval.interval()));
class_interval_context.callee_interval()));
boost::hash_combine(
seed, std::hash<bool>()(callee_interval.preserves_type_context()));
seed,
std::hash<bool>()(class_interval_context.preserves_type_context()));
return seed;
}
};
8 changes: 4 additions & 4 deletions source/CallPositionFrames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ CallPositionFrames CallPositionFrames::propagate(
Context& context,
const std::vector<const DexType * MT_NULLABLE>& source_register_types,
const std::vector<std::optional<std::string>>& source_constant_arguments,
const CalleeInterval& callee_interval,
const CallClassIntervalContext& class_interval_context,
const ClassIntervals::Interval& caller_class_interval) const {
if (is_bottom()) {
return CallPositionFrames::bottom();
Expand All @@ -75,7 +75,7 @@ CallPositionFrames CallPositionFrames::propagate(
context,
source_register_types,
source_constant_arguments,
callee_interval,
class_interval_context,
caller_class_interval);
result.update(
propagated.callee_port(), [&propagated](CalleePortFrames* frames) {
Expand Down Expand Up @@ -142,7 +142,7 @@ CallPositionFrames CallPositionFrames::attach_position(
/* call_position */ position,
// TODO(T158171922): Re-visit what the appropriate interval
// should be when implementing class intervals.
frame.callee_interval(),
frame.class_interval_context(),
/* distance */ 0,
frame.origins(),
frame.field_origins(),
Expand Down Expand Up @@ -213,7 +213,7 @@ CallPositionFrames::map_positions(
frame.call_info(),
frame.field_callee(),
call_position,
frame.callee_interval(),
frame.class_interval_context(),
frame.distance(),
frame.origins(),
frame.field_origins(),
Expand Down
2 changes: 1 addition & 1 deletion source/CallPositionFrames.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class CallPositionFrames final : public FramesMap<
Context& context,
const std::vector<const DexType * MT_NULLABLE>& source_register_types,
const std::vector<std::optional<std::string>>& source_constant_arguments,
const CalleeInterval& callee_interval,
const CallClassIntervalContext& class_interval_context,
const ClassIntervals::Interval& caller_class_interval) const;

/* Return the set of leaf frames with the given position. */
Expand Down
4 changes: 2 additions & 2 deletions source/CalleeFrames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ CalleeFrames CalleeFrames::propagate(
Context& context,
const std::vector<const DexType * MT_NULLABLE>& source_register_types,
const std::vector<std::optional<std::string>>& source_constant_arguments,
const CalleeInterval& callee_interval,
const CallClassIntervalContext& class_interval_context,
const ClassIntervals::Interval& caller_class_interval) const {
if (is_bottom()) {
return CalleeFrames::bottom();
Expand All @@ -86,7 +86,7 @@ CalleeFrames CalleeFrames::propagate(
context,
source_register_types,
source_constant_arguments,
callee_interval,
class_interval_context,
caller_class_interval));
}

Expand Down
2 changes: 1 addition & 1 deletion source/CalleeFrames.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class CalleeFrames final : public FramesMap<
Context& context,
const std::vector<const DexType * MT_NULLABLE>& source_register_types,
const std::vector<std::optional<std::string>>& source_constant_arguments,
const CalleeInterval& callee_interval,
const CallClassIntervalContext& class_interval_context,
const ClassIntervals::Interval& caller_class_interval) const;

/**
Expand Down
42 changes: 0 additions & 42 deletions source/CalleeInterval.cpp

This file was deleted.

6 changes: 3 additions & 3 deletions source/CalleePortFrames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ CalleePortFrames CalleePortFrames::propagate(
Context& context,
const std::vector<const DexType * MT_NULLABLE>& source_register_types,
const std::vector<std::optional<std::string>>& source_constant_arguments,
const CalleeInterval& callee_interval,
const CallClassIntervalContext& class_interval_context,
const ClassIntervals::Interval& caller_class_interval) const {
if (is_bottom()) {
return CalleePortFrames::bottom();
Expand Down Expand Up @@ -233,7 +233,7 @@ CalleePortFrames CalleePortFrames::propagate(
maximum_source_sink_distance,
context,
source_register_types,
callee_interval,
class_interval_context,
caller_class_interval);
} else {
std::vector<const Feature*> via_type_of_features_added;
Expand All @@ -247,7 +247,7 @@ CalleePortFrames CalleePortFrames::propagate(
source_register_types,
source_constant_arguments,
via_type_of_features_added,
callee_interval,
class_interval_context,
caller_class_interval);
}

Expand Down
2 changes: 1 addition & 1 deletion source/CalleePortFrames.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class CalleePortFrames final : public sparta::AbstractDomain<CalleePortFrames> {
Context& context,
const std::vector<const DexType * MT_NULLABLE>& source_register_types,
const std::vector<std::optional<std::string>>& source_constant_arguments,
const CalleeInterval& callee_interval,
const CallClassIntervalContext& class_interval_context,
const ClassIntervals::Interval& caller_class_interval) const;

/**
Expand Down
2 changes: 1 addition & 1 deletion source/ForwardTaintTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ void check_flows_to_array_allocation(
/* call_info */ CallInfo::origin(),
/* field_callee */ nullptr,
/* call_position */ position,
/* callee_interval */ CalleeInterval(),
/* class_interval_context */ CallClassIntervalContext(),
/* distance */ 1,
/* origins */ MethodSet{array_allocation_method},
/* field_origins */ {},
Expand Down
16 changes: 8 additions & 8 deletions source/Frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Frame::Frame(const TaintConfig& config)
config.callee(),
config.field_callee(),
config.call_position(),
config.callee_interval(),
config.class_interval_context(),
config.distance(),
config.origins(),
config.field_origins(),
Expand Down Expand Up @@ -92,7 +92,7 @@ bool Frame::leq(const Frame& other) const {
return kind_ == other.kind_ && callee_port_ == other.callee_port_ &&
callee_ == other.callee_ && call_position_ == other.call_position_ &&
call_info_ == other.call_info_ && distance_ >= other.distance_ &&
callee_interval_ == other.callee_interval_ &&
class_interval_context_ == other.class_interval_context_ &&
origins_.leq(other.origins_) &&
field_origins_.leq(other.field_origins_) &&
inferred_features_.leq(other.inferred_features_) &&
Expand All @@ -114,7 +114,7 @@ bool Frame::equals(const Frame& other) const {
return kind_ == other.kind_ && callee_port_ == other.callee_port_ &&
callee_ == other.callee_ && call_position_ == other.call_position_ &&
call_info_ == other.call_info_ &&
callee_interval_ == other.callee_interval_ &&
class_interval_context_ == other.class_interval_context_ &&
distance_ == other.distance_ && origins_ == other.origins_ &&
field_origins_ == other.field_origins_ &&
inferred_features_ == other.inferred_features_ &&
Expand All @@ -140,7 +140,7 @@ void Frame::join_with(const Frame& other) {
mt_assert(call_position_ == other.call_position_);
mt_assert(callee_port_ == other.callee_port_);
mt_assert(call_info_ == other.call_info_);
mt_assert(callee_interval_ == other.callee_interval_);
mt_assert(class_interval_context_ == other.class_interval_context_);

distance_ = std::min(distance_, other.distance_);
origins_.join_with(other.origins_);
Expand Down Expand Up @@ -199,7 +199,7 @@ Frame Frame::update_with_propagation_trace(
/* field_callee */ nullptr, // Since propagate is only called at method
// callsites and not field accesses
propagation_frame.call_position_,
callee_interval_,
class_interval_context_,
propagation_frame.distance_,
propagation_frame.origins_,
field_origins_,
Expand Down Expand Up @@ -406,8 +406,8 @@ Json::Value Frame::to_json(ExportOriginsMode export_origins_mode) const {
value["output_paths"] = output_paths_value;
}

if (!callee_interval_.is_default()) {
auto interval_json = callee_interval_.to_json();
if (!class_interval_context_.is_default()) {
auto interval_json = class_interval_context_.to_json();
for (const auto& member : interval_json.getMemberNames()) {
value[member] = interval_json[member];
}
Expand Down Expand Up @@ -435,7 +435,7 @@ std::ostream& operator<<(std::ostream& out, const Frame& frame) {
if (frame.call_position_ != nullptr) {
out << ", call_position=" << show(frame.call_position_);
}
out << ", callee_interval=" << show(frame.callee_interval_);
out << ", class_interval_context=" << show(frame.class_interval_context_);
out << ", call_info=" << frame.call_info_.to_trace_string();
if (frame.distance_ != 0) {
out << ", distance=" << frame.distance_;
Expand Down
Loading

0 comments on commit cc5057b

Please sign in to comment.