Design doc for operator attribute #2606
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Design Doc: Operator Attributes | ||
|
||
## Background | ||
|
||
An operator can have attributes. For example, `CosineOp` could have a float-typed attribute `scale`, which changes the output range from `[-1, 1]` to `[-scale, scale]`. The default value of `scale` is `1.0`. | ||
|
||
An attribute is defined by a name and a type. An instance of an attribute has a value of that type. | ||
|
||
As part of the network description, attributes need to be serialized, so we need a protobuf message that describes an attribute, say `Attribute`. | ||
|
||
An operator can parse its `Attribute` messages and save the values into its private data members. | ||
|
||
## Protobuf Implementation | ||
|
||
Two frameworks implement the `Attribute` concept in `protobuf`: [`caffe2`](https://github.com/caffe2/caffe2/blob/master/caffe2/proto/caffe2.proto#L98) and [`tensorflow`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L16). | ||
|
||
* Caffe2 uses `proto2` syntax. It stores all attributes in a list, and each attribute contains a `name`. Each time Caffe2 reads an attribute, it searches the list for that name, which is slow if the number of attributes is large. Caffe2 also marks all fields as `optional`, so it cannot ensure that exactly one attribute value is set. | ||
* TensorFlow uses `proto3` syntax and implements attributes with the `map` and `oneof` keywords. Looking up an attribute in TensorFlow's map is fast. | ||
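
The lookup-cost argument in the list above can be sketched in plain C++. This is only an illustration of list-style versus map-style storage, not code taken from Caffe2 or TensorFlow; `Attr`, `FindInList`, and `FindInMap` are hypothetical names introduced here.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an attribute value; a real framework would
// store a protobuf message here.
struct Attr {
  float f;
};

// List-style storage: each lookup scans the entries until the name matches,
// so the cost grows linearly with the number of attributes.
inline const Attr* FindInList(
    const std::vector<std::pair<std::string, Attr>>& attrs,
    const std::string& name) {
  for (const auto& entry : attrs) {
    if (entry.first == name) return &entry.second;
  }
  return nullptr;
}

// Map-style storage: the container indexes entries by name, so a lookup
// does not have to visit every attribute.
inline const Attr* FindInMap(const std::map<std::string, Attr>& attrs,
                             const std::string& name) {
  auto it = attrs.find(name);
  return it == attrs.end() ? nullptr : &it->second;
}
```

Both helpers return `nullptr` for a missing name, so a caller can tell "not set" apart from a default value.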
|
||
Paddle uses `protobuf 3` as its dependency library. By simplifying `tensorflow`'s implementation, Paddle's `Attribute` protobuf message schema could be | ||
|
||
```protobuf | ||
message Attribute { | ||
message ListValue { | ||
repeated int32 ints = 1; | ||
repeated float floats = 2; | ||
repeated string strings = 3; | ||
} | ||
|
||
oneof value { | ||
ListValue list = 1; | ||
// Review comment: The oneof directive doesn't seem to help much here -- it can locate a field in
//   syntax = "proto3";
//   message Attribute {
//     enum Type {
//       INTS = 0;
//       FLOATS = 1;
//       STRINGS = 2;
//       INT = 3;
//       FLOAT = 4;
//       STRING = 5;
//     }
//     Type type = 1;
//     repeated int32 ints = 2;
//     repeated float floats = 3;
//     repeated string strings = 4;
//     int32 int = 5;
//     float float = 6;
//     string string = 7;
//   }
// Review comment:
//   syntax = "proto3";
//   message Attribute {
//     repeated int32 ints = 1;
//     repeated float floats = 2;
//     repeated string strings = 3;
//     optional int32 int = 4;
//     optional float float = 5;
//     optional string string = 6;
//   }
// Maybe the |
||
int32 i = 2; | ||
float f = 3; | ||
string s = 4; | ||
} | ||
} | ||
``` | ||
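
To make the `oneof` behavior concrete, here is a minimal hand-written C++ analog of the proposed message. It is a sketch under stated assumptions: `AttributeSketch` is a hypothetical class, not the code `protoc` would generate, and it only mirrors two ideas: setting one member clears the others, and `value_case()` reports which member is active.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical hand-written analog of the proposed `Attribute` message.
// The real class would be generated by protoc; this only mirrors the idea
// that a `oneof` holds at most one value at a time.
class AttributeSketch {
 public:
  enum ValueCase { kNotSet, kList, kI, kF, kS };

  // Each setter clears whatever was stored before, as a `oneof` does.
  void set_i(int v) { Clear(); i_ = v; case_ = kI; }
  void set_f(float v) { Clear(); f_ = v; case_ = kF; }
  void set_s(const std::string& v) { Clear(); s_ = v; case_ = kS; }
  void set_ints(const std::vector<int>& v) { Clear(); ints_ = v; case_ = kList; }

  ValueCase value_case() const { return case_; }

  // Getters return a default value when a different member is active.
  int i() const { return case_ == kI ? i_ : 0; }
  float f() const { return case_ == kF ? f_ : 0.0f; }
  const std::string& s() const { return s_; }
  const std::vector<int>& ints() const { return ints_; }

 private:
  void Clear() {
    i_ = 0;
    f_ = 0.0f;
    s_.clear();
    ints_.clear();
    case_ = kNotSet;
  }

  ValueCase case_ = kNotSet;
  int i_ = 0;
  float f_ = 0.0f;
  std::string s_;
  std::vector<int> ints_;
};
```

Calling `set_i(3)` and then `set_f(1.5f)` leaves only the float active, which is the guarantee the Caffe2-style all-`optional` layout does not give.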
|
||
The `OperatorDescription` message should contain a field like this: | ||
Review comment: OperatorDescription => OperatorDesc? |
||
|
||
```protobuf | ||
message OperatorDescription { | ||
map<string, Attribute> attrs = 1;  // placeholder field number | ||
} | ||
``` | ||
|
||
## C++ Implementation | ||
|
||
### AttributeReader | ||
|
||
In C++, there should be a helper class for reading `map<string, Attribute>`. Its reading methods accept a template parameter that specifies the attribute type. We name this helper class `AttributeReader`. | ||
|
||
The interface of `AttributeReader` is like this: | ||
|
||
```cpp | ||
using AttributeMap = google::protobuf::Map<std::string, Attribute>; | ||
class AttributeReader { | ||
// Review comment: AttributeReader => Attributes. This is not a reader; it holds and searches attributes as well. |
||
public: | ||
explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {} | ||
|
||
// Get the attribute named `name` as a plain type T. | ||
template <typename T> | ||
T Get(const std::string& name) const; | ||
|
||
// Get the attribute named `name` as an array of type T. | ||
template <typename T> | ||
// Review comment: What if an attribute was marked int in the protobuf message, but the user tries to retrieve its string value? How can we check for and find this kind of error?
// Reply: Because we use |
||
void GetArray(const std::string& name, std::vector<T>* vec) const; | ||
|
||
// Whether the map contains an attribute named `name` with type T. | ||
// T can be int, float, string, or a std::vector of them. | ||
template <typename T> | ||
bool Contains(const std::string& name) const; | ||
// Review comment: Contains => IsType. For example,
//   if (IsType<std::vector<string>>("name")) {
//     ...
//   }
// Actually, maybe it would be easier to have
//   class Attributes {
//    public:
//     bool Has(const std::string& name) const {
//       return attrs_.find(name) != attrs_.end();
//     }
//     const std::type_info& Type(const std::string& name) const {
//       PADDLE_ENFORCE(Has(name));
//       switch (attrs_.at(name).type()) {
//         case paddle::framework::Attribute::INTS:
//           return typeid(int);
//         ...
//       }
//     }
//   }; |
||
|
||
private: | ||
const AttributeMap& attrs_; | ||
}; | ||
``` | ||
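
One way the type check behind `Contains<T>` and `Get<T>` could work is sketched below without any protobuf dependency. Everything here is an assumption for illustration: `FakeAttribute`, `AttrTraits`, and `ReaderSketch` are hypothetical names, `std::map` stands in for the protobuf map, and `std::runtime_error` stands in for `EnforceNotMet`.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the generated protobuf class.
struct FakeAttribute {
  enum Case { kNotSet, kI, kF, kS };
  Case value_case = kNotSet;
  int i = 0;
  float f = 0.0f;
  std::string s;
};
using FakeAttributeMap = std::map<std::string, FakeAttribute>;

// Maps a C++ type to the case it must match and to the field to read.
template <typename T>
struct AttrTraits;
template <>
struct AttrTraits<int> {
  static FakeAttribute::Case Case() { return FakeAttribute::kI; }
  static int Read(const FakeAttribute& a) { return a.i; }
};
template <>
struct AttrTraits<float> {
  static FakeAttribute::Case Case() { return FakeAttribute::kF; }
  static float Read(const FakeAttribute& a) { return a.f; }
};
template <>
struct AttrTraits<std::string> {
  static FakeAttribute::Case Case() { return FakeAttribute::kS; }
  static std::string Read(const FakeAttribute& a) { return a.s; }
};

class ReaderSketch {
 public:
  explicit ReaderSketch(const FakeAttributeMap& attrs) : attrs_(attrs) {}

  // True only if the name exists and its stored case matches T.
  template <typename T>
  bool Contains(const std::string& name) const {
    auto it = attrs_.find(name);
    return it != attrs_.end() &&
           it->second.value_case == AttrTraits<T>::Case();
  }

  // Throws if the name is missing or was stored with another type.
  template <typename T>
  T Get(const std::string& name) const {
    if (!Contains<T>(name)) {
      throw std::runtime_error("attribute missing or type mismatch: " + name);
    }
    return AttrTraits<T>::Read(attrs_.at(name));
  }

 private:
  const FakeAttributeMap& attrs_;
};
```

With this shape, asking for `Get<int>("scale")` when `scale` was stored as a float fails loudly instead of returning a default value.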
|
||
### Attribute in Operator | ||
|
||
Each operator parses its attributes and stores them in private data members in `InitializeAttribute`. That method is invoked by `CreateOperator`. Users can use `PADDLE_ENFORCE` to validate attributes. With the `Contains` method, users can also fall back to a default value when an attribute is not set. | ||
|
||
```cpp | ||
class OperatorBase { | ||
public: | ||
virtual void InitializeAttribute(const AttributeReader& attrs) = 0; | ||
}; | ||
|
||
class CosineOp : public OperatorBase { | ||
public: | ||
void InitializeAttribute(const AttributeReader& attrs) override { | ||
if (attrs.Contains<float>("scale")) { | ||
// Review comment: This doesn't look like a smart idea -- every time we change an attribute in the |
||
scale_ = attrs.Get<float>("scale"); | ||
PADDLE_ENFORCE(scale_ > 0.0f, "Scale of cosine op should be larger than 0.0"); | ||
} | ||
} | ||
|
||
private: | ||
float scale_ {1.0}; | ||
}; | ||
``` | ||
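
The default-value pattern above can be tried out with a simplified, dependency-free sketch. The assumptions are loud ones: attributes are reduced to a plain `std::map<std::string, float>`, `CosineOpSketch` is a hypothetical class, and `std::invalid_argument` stands in for the error raised by `PADDLE_ENFORCE`. Unlike the snippet above, this variant validates the value before storing it, so a failed initialization leaves the default in place.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical simplification: attributes are plain floats keyed by name.
using FloatAttrs = std::map<std::string, float>;

class CosineOpSketch {
 public:
  void InitializeAttribute(const FloatAttrs& attrs) {
    auto it = attrs.find("scale");
    if (it == attrs.end()) return;  // attribute not set: keep the default
    if (!(it->second > 0.0f)) {
      throw std::invalid_argument("scale of cosine op should be larger than 0.0");
    }
    scale_ = it->second;  // store only after validation succeeds
  }

  float scale() const { return scale_; }

 private:
  float scale_{1.0f};  // default value used when "scale" is absent
};
```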
|
||
`InitializeAttribute` is invoked by `CreateOperator`. Since `InitializeAttribute` could throw an `EnforceNotMet`, a `unique_ptr` is used to make the code exception-safe. | ||
Review comment: It seems that this document is not complete. I cannot find the part describing the macros that define the attributes of an operator. |
||
|
||
Review comment: Also, I am curious about how the Python API can fill in the protobuf field |
||
```cpp | ||
std::unique_ptr<OperatorBase> CreateOperator(const OperatorDescription& desc) { | ||
std::unique_ptr<OperatorBase> op(OperatorRegister.Create( | ||
desc.type(), desc.inputs(), desc.outputs())); | ||
op->InitializeAttribute(AttributeReader(desc.attrs())); | ||
return op; | ||
} | ||
``` |
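
The exception-safety claim can be checked with a small stand-alone sketch. `CountedOp` and `CreateCountedOp` are hypothetical names, and the live-object counter exists only to make a leak observable; the point is that when initialization throws, the `unique_ptr` destructor still releases the operator.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical operator that counts live instances so a leak would show up.
struct CountedOp {
  static int& Live() {
    static int n = 0;
    return n;
  }
  explicit CountedOp(bool fail) : fail_(fail) { ++Live(); }
  ~CountedOp() { --Live(); }

  // Stands in for InitializeAttribute; throws when told to fail.
  void InitializeAttribute() {
    if (fail_) throw std::runtime_error("bad attribute");
  }

  bool fail_;
};

inline std::unique_ptr<CountedOp> CreateCountedOp(bool fail) {
  std::unique_ptr<CountedOp> op(new CountedOp(fail));
  op->InitializeAttribute();  // if this throws, op's destructor frees it
  return op;                  // moved out implicitly on success
}
```

With a raw `new` and no smart pointer, the throwing path would skip the `delete` and the counter would stay at one.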
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
syntax="proto3"; | ||
package paddle.framework; | ||
|
||
// Attribute for OperatorDescription. | ||
// It represents a variant-typed value. | ||
// int, float, string, and repeated lists of them are supported. | ||
message Attribute { | ||
message ListValue { | ||
repeated int32 ints = 1; | ||
repeated float floats = 2; | ||
repeated string strings = 3; | ||
} | ||
|
||
oneof value { | ||
// Since proto3 does not support repeated fields in `oneof`, we use | ||
// ListValue to hold repeated values. | ||
ListValue list = 1; | ||
|
||
int32 i = 2; | ||
float f = 3; | ||
string s = 4; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include <google/protobuf/map.h> | ||
#include <paddle/framework/attribute.pb.h> | ||
#include <paddle/framework/enforce.h> | ||
#include <algorithm> | ||
#include <iterator> | ||
#include <string> | ||
#include <type_traits> | ||
#include <typeinfo> | ||
#include <vector> | ||
|
||
namespace paddle { | ||
namespace framework { | ||
using AttributeMap = google::protobuf::Map<std::string, Attribute>; | ||
class AttributeReader { | ||
public: | ||
explicit AttributeReader(const AttributeMap& attrs) : attrs_(attrs) {} | ||
|
||
/** | ||
* @brief Check whether an attribute with the given name and type T exists. | ||
* | ||
* Example: | ||
* @code{.cpp} | ||
* AttributeReader reader(attrs); | ||
* | ||
* assert(reader.Contains<int>("SomeIntValue")==true); | ||
* assert(reader.Contains<float>("SomeIntValue")==false); | ||
* assert(reader.Contains<std::vector<int>>("SomeIntList")==true); | ||
* @endcode | ||
* | ||
* @tparam T Attribute type; can be int, float, string, or a std::vector of | ||
* them. | ||
* @param name attribute name | ||
* @return true if an attribute with the given name and type T exists; false | ||
* if the type mismatches or the name is not found. | ||
*/ | ||
template <typename T> | ||
bool Contains(const std::string& name) const; | ||
|
||
/** | ||
* @brief Get an attribute value. std::vector is not supported; to get a | ||
* std::vector, use `GetArray`. | ||
* @tparam T could be int, float, string. | ||
* @param name attribute name. | ||
* @return Value | ||
* @throw EnforceNotMet if the attribute is not found or the type | ||
* mismatches. | ||
*/ | ||
template <typename T> | ||
T Get(const std::string& name) const; | ||
// Review comment: Just to discuss: do we need a `T Get(const std::string& name, const T& default) const;` to set a default value?
// Reply: I see from the design doc that people can use |
||
|
||
/** | ||
* @brief Get Attribute array values. | ||
* @tparam T could be int, float, string. | ||
* @param name attribute name. | ||
* @param vec the output vector. Must be empty. | ||
* @throw EnforceNotMet if the attribute is not found, the type mismatches, | ||
* or vec is not empty. | ||
*/ | ||
template <typename T> | ||
void GetArray(const std::string& name, std::vector<T>* vec) const; | ||
|
||
private: | ||
const AttributeMap& attrs_; | ||
}; | ||
|
||
/// Implementation of Contains | ||
namespace details { | ||
// Review comment: why named
// Reply: Because we need to define some functions that should not be used by the user but can only be defined in the header file. So mark them in |
||
inline const ::paddle::framework::Attribute* GetField(const AttributeMap& attrs, | ||
const std::string& name) { | ||
auto it = attrs.find(name); | ||
if (it == attrs.end()) { | ||
return nullptr; | ||
} else { | ||
return &it->second; | ||
} | ||
} | ||
|
||
template <typename T> | ||
inline bool IsType(const ::paddle::framework::Attribute* attr); | ||
|
||
template <typename T, bool IsArray> | ||
struct ContainsImpl {}; | ||
|
||
template <typename T> | ||
struct ContainsImpl<T, false> { | ||
bool operator()(const AttributeMap& attrs, const std::string& name) { | ||
auto attr = GetField(attrs, name); | ||
if (attr) { | ||
return details::IsType<T>(attr); | ||
} else { | ||
return false; | ||
} | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct ContainsImpl<T, true> { | ||
bool operator()(const AttributeMap& attrs, const std::string& name) { | ||
auto attr = GetField(attrs, name); | ||
if (attr) { | ||
return attr->has_list(); | ||
} else { | ||
return false; | ||
} | ||
} | ||
}; | ||
} // namespace details | ||
|
||
template <typename T> | ||
bool AttributeReader::Contains(const std::string& name) const { | ||
constexpr bool is_vec = std::is_same<T, std::vector<int>>::value || | ||
std::is_same<T, std::vector<float>>::value || | ||
std::is_same<T, std::vector<std::string>>::value; | ||
return details::ContainsImpl<T, is_vec>()(attrs_, name); | ||
} | ||
|
||
#define ATTR_READER_ISTYPE_IMPL(T, CASE) \ | ||
namespace details { \ | ||
template <> \ | ||
inline bool IsType<T>(const ::paddle::framework::Attribute* attr) { \ | ||
return attr->value_case() == CASE; \ | ||
} \ | ||
} | ||
|
||
ATTR_READER_ISTYPE_IMPL(int, ::paddle::framework::Attribute::kI); | ||
ATTR_READER_ISTYPE_IMPL(float, ::paddle::framework::Attribute::kF); | ||
ATTR_READER_ISTYPE_IMPL(std::string, ::paddle::framework::Attribute::kS); | ||
|
||
#undef ATTR_READER_ISTYPE_IMPL | ||
|
||
/// Implementation of Get | ||
namespace details { | ||
template <typename T> | ||
inline T GetValue(const ::paddle::framework::Attribute* attr); | ||
} | ||
|
||
template <typename T> | ||
T AttributeReader::Get(const std::string& name) const { | ||
auto attr = details::GetField(attrs_, name); | ||
PADDLE_ENFORCE(attr != nullptr, "Attribute %s not found", name); | ||
PADDLE_ENFORCE(details::IsType<T>(attr), | ||
"Attribute type mismatch. Expected %s", typeid(T).name()); | ||
return details::GetValue<T>(attr); | ||
} | ||
|
||
#define ATTR_READER_GETVALUE_IMPL(T, FIELD) \ | ||
namespace details { \ | ||
template <> \ | ||
inline T GetValue<T>(const ::paddle::framework::Attribute* attr) { \ | ||
return attr->FIELD(); \ | ||
} \ | ||
} | ||
|
||
ATTR_READER_GETVALUE_IMPL(int, i); | ||
ATTR_READER_GETVALUE_IMPL(float, f); | ||
ATTR_READER_GETVALUE_IMPL(std::string, s); | ||
|
||
#undef ATTR_READER_GETVALUE_IMPL | ||
|
||
/// Implementation of GetArray | ||
#define ATTR_GETARRAY_IMPL(T, FIELD) \ | ||
template <> \ | ||
inline void AttributeReader::GetArray<T>(const std::string& name, \ | ||
std::vector<T>* vec) const { \ | ||
PADDLE_ENFORCE(vec->empty(), "Input vector should be empty"); \ | ||
auto attr = details::GetField(attrs_, name); \ | ||
PADDLE_ENFORCE(attr != nullptr, "Attribute %s not found", name); \ | ||
PADDLE_ENFORCE(attr->has_list(), "Attribute %s is not array", name); \ | ||
auto& field = attr->list().FIELD(); \ | ||
vec->reserve(field.size()); \ | ||
std::copy(field.begin(), field.end(), std::back_inserter(*vec)); \ | ||
} | ||
|
||
ATTR_GETARRAY_IMPL(int, ints); | ||
// Review comment: To keep the naming style consistent: ATTR_GETARRAY_IMPL => ATTR_READER_GETARRAY_IMPL |
||
ATTR_GETARRAY_IMPL(float, floats); | ||
ATTR_GETARRAY_IMPL(std::string, strings); | ||
|
||
#undef ATTR_GETARRAY_IMPL | ||
|
||
} // namespace framework | ||
} // namespace paddle |
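
The scalar-versus-array dispatch used by `Contains` in the header above can be isolated in a short sketch. It is a variation, not the header's code: a hypothetical trait `IsAttrVector` replaces the chain of `std::is_same` checks, and `KindImpl` and `KindOf` are illustrative names that only report which branch was selected at compile time.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical trait: true for any std::vector<T>, false otherwise.
template <typename T>
struct IsAttrVector : std::false_type {};
template <typename T>
struct IsAttrVector<std::vector<T>> : std::true_type {};

// Primary template is left undefined; only the two partial
// specializations below are ever instantiated.
template <typename T, bool IsArray>
struct KindImpl;

template <typename T>
struct KindImpl<T, false> {
  const char* operator()() const { return "scalar"; }
};

template <typename T>
struct KindImpl<T, true> {
  const char* operator()() const { return "array"; }
};

// Selects the specialization from the type alone, with no runtime branch.
template <typename T>
std::string KindOf() {
  return KindImpl<T, IsAttrVector<T>::value>()();
}
```

The same shape lets the scalar branch call a per-type check while the array branch checks for the list member, as `ContainsImpl` does.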
Review comment: Surprised to learn that Caffe2 searches attributes in a protobuf list. The right way is to load from the protobuf field `repeated Attribute attrs = x;` into the C++ data structure `map<string, Attribute>` and search in the C++ data structure.
Reply: It's my mistake; Caffe2 loads all attributes into memory first.