-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-6055] [PySpark] fix incorrect __eq__ of DataType #4808
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
import warnings | ||
import json | ||
import re | ||
import weakref | ||
from array import array | ||
from operator import itemgetter | ||
|
||
|
@@ -42,8 +43,7 @@ def __hash__(self): | |
return hash(str(self)) | ||
|
||
def __eq__(self, other): | ||
return (isinstance(other, self.__class__) and | ||
self.__dict__ == other.__dict__) | ||
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
@@ -64,6 +64,8 @@ def json(self): | |
sort_keys=True) | ||
|
||
|
||
# This singleton pattern does not work with pickle, you will get | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would it be possible to have singletons that pickled properly if we implemented a custom `__reduce__` method? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We discussed this; the old implementation is a bit of an over-optimization, I think. For PrimitiveTypeSingleton, the [rest of comment truncated in page capture] |
||
# another object after pickle and unpickle | ||
class PrimitiveTypeSingleton(type): | ||
|
||
"""Metaclass for PrimitiveType""" | ||
|
@@ -82,10 +84,6 @@ class PrimitiveType(DataType): | |
|
||
__metaclass__ = PrimitiveTypeSingleton | ||
|
||
def __eq__(self, other): | ||
# because they should be the same object | ||
return self is other | ||
|
||
|
||
class NullType(PrimitiveType): | ||
|
||
|
@@ -242,11 +240,12 @@ def __init__(self, elementType, containsNull=True): | |
:param elementType: the data type of elements. | ||
:param containsNull: indicates whether the list contains None values. | ||
|
||
>>> ArrayType(StringType) == ArrayType(StringType, True) | ||
>>> ArrayType(StringType()) == ArrayType(StringType(), True) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this a breaking API change? Or were the old doctests showing incorrect usage of the API? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The old tests are incorrect. |
||
True | ||
>>> ArrayType(StringType, False) == ArrayType(StringType) | ||
>>> ArrayType(StringType(), False) == ArrayType(StringType()) | ||
False | ||
""" | ||
assert isinstance(elementType, DataType), "elementType should be DataType" | ||
self.elementType = elementType | ||
self.containsNull = containsNull | ||
|
||
|
@@ -292,13 +291,15 @@ def __init__(self, keyType, valueType, valueContainsNull=True): | |
:param valueContainsNull: indicates whether values contains | ||
null values. | ||
|
||
>>> (MapType(StringType, IntegerType) | ||
... == MapType(StringType, IntegerType, True)) | ||
>>> (MapType(StringType(), IntegerType()) | ||
... == MapType(StringType(), IntegerType(), True)) | ||
True | ||
>>> (MapType(StringType, IntegerType, False) | ||
... == MapType(StringType, FloatType)) | ||
>>> (MapType(StringType(), IntegerType(), False) | ||
... == MapType(StringType(), FloatType())) | ||
False | ||
""" | ||
assert isinstance(keyType, DataType), "keyType should be DataType" | ||
assert isinstance(valueType, DataType), "valueType should be DataType" | ||
self.keyType = keyType | ||
self.valueType = valueType | ||
self.valueContainsNull = valueContainsNull | ||
|
@@ -348,13 +349,14 @@ def __init__(self, name, dataType, nullable=True, metadata=None): | |
to simple type that can be serialized to JSON | ||
automatically | ||
|
||
>>> (StructField("f1", StringType, True) | ||
... == StructField("f1", StringType, True)) | ||
>>> (StructField("f1", StringType(), True) | ||
... == StructField("f1", StringType(), True)) | ||
True | ||
>>> (StructField("f1", StringType, True) | ||
... == StructField("f2", StringType, True)) | ||
>>> (StructField("f1", StringType(), True) | ||
... == StructField("f2", StringType(), True)) | ||
False | ||
""" | ||
assert isinstance(dataType, DataType), "dataType should be DataType" | ||
self.name = name | ||
self.dataType = dataType | ||
self.nullable = nullable | ||
|
@@ -393,16 +395,17 @@ class StructType(DataType): | |
def __init__(self, fields): | ||
"""Creates a StructType | ||
|
||
>>> struct1 = StructType([StructField("f1", StringType, True)]) | ||
>>> struct2 = StructType([StructField("f1", StringType, True)]) | ||
>>> struct1 = StructType([StructField("f1", StringType(), True)]) | ||
>>> struct2 = StructType([StructField("f1", StringType(), True)]) | ||
>>> struct1 == struct2 | ||
True | ||
>>> struct1 = StructType([StructField("f1", StringType, True)]) | ||
>>> struct2 = StructType([StructField("f1", StringType, True), | ||
... [StructField("f2", IntegerType, False)]]) | ||
>>> struct1 = StructType([StructField("f1", StringType(), True)]) | ||
>>> struct2 = StructType([StructField("f1", StringType(), True), | ||
... StructField("f2", IntegerType(), False)]) | ||
>>> struct1 == struct2 | ||
False | ||
""" | ||
assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" | ||
self.fields = fields | ||
|
||
def simpleString(self): | ||
|
@@ -505,6 +508,9 @@ def __eq__(self, other): | |
|
||
def _parse_datatype_json_string(json_string): | ||
"""Parses the given data type JSON string. | ||
>>> import pickle | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like this is a regression test for the singleton pickling issue; maybe it belongs in the test suite rather than in this docstring. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. e.g. maybe put it in `tests.py`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense, will do. |
||
>>> LongType() == pickle.loads(pickle.dumps(LongType())) | ||
True | ||
>>> def check_datatype(datatype): | ||
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) | ||
... python_datatype = _parse_datatype_json_string(scala_datatype.json()) | ||
|
@@ -786,8 +792,24 @@ def _merge_type(a, b): | |
return a | ||
|
||
|
||
def _need_converter(dataType): | ||
if isinstance(dataType, StructType): | ||
return True | ||
elif isinstance(dataType, ArrayType): | ||
return _need_converter(dataType.elementType) | ||
elif isinstance(dataType, MapType): | ||
return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) | ||
elif isinstance(dataType, NullType): | ||
return True | ||
else: | ||
return False | ||
|
||
|
||
def _create_converter(dataType): | ||
"""Create an converter to drop the names of fields in obj """ | ||
if not _need_converter(dataType): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this particular call to `_need_converter` necessary here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's useful for ArrayType and MapType, since it lets them skip building converters when none are needed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, right; I overlooked that. Makes sense. |
||
return lambda x: x | ||
|
||
if isinstance(dataType, ArrayType): | ||
conv = _create_converter(dataType.elementType) | ||
return lambda row: map(conv, row) | ||
|
@@ -806,13 +828,17 @@ def _create_converter(dataType): | |
# dataType must be StructType | ||
names = [f.name for f in dataType.fields] | ||
converters = [_create_converter(f.dataType) for f in dataType.fields] | ||
convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) | ||
|
||
def convert_struct(obj): | ||
if obj is None: | ||
return | ||
|
||
if isinstance(obj, (tuple, list)): | ||
return tuple(conv(v) for v, conv in zip(obj, converters)) | ||
if convert_fields: | ||
return tuple(conv(v) for v, conv in zip(obj, converters)) | ||
else: | ||
return tuple(obj) | ||
|
||
if isinstance(obj, dict): | ||
d = obj | ||
|
@@ -821,7 +847,10 @@ def convert_struct(obj): | |
else: | ||
raise ValueError("Unexpected obj: %s" % obj) | ||
|
||
return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) | ||
if convert_fields: | ||
return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) | ||
else: | ||
return tuple([d.get(name) for name in names]) | ||
|
||
return convert_struct | ||
|
||
|
@@ -871,20 +900,20 @@ def _parse_field_abstract(s): | |
Parse a field in schema abstract | ||
|
||
>>> _parse_field_abstract("a") | ||
StructField(a,None,true) | ||
StructField(a,NullType,true) | ||
>>> _parse_field_abstract("b(c d)") | ||
StructField(b,StructType(...c,None,true),StructField(d... | ||
StructField(b,StructType(...c,NullType,true),StructField(d... | ||
>>> _parse_field_abstract("a[]") | ||
StructField(a,ArrayType(None,true),true) | ||
StructField(a,ArrayType(NullType,true),true) | ||
>>> _parse_field_abstract("a{[]}") | ||
StructField(a,MapType(None,ArrayType(None,true),true),true) | ||
StructField(a,MapType(NullType,ArrayType(NullType,true),true),true) | ||
""" | ||
if set(_BRACKETS.keys()) & set(s): | ||
idx = min((s.index(c) for c in _BRACKETS if c in s)) | ||
name = s[:idx] | ||
return StructField(name, _parse_schema_abstract(s[idx:]), True) | ||
else: | ||
return StructField(s, None, True) | ||
return StructField(s, NullType(), True) | ||
|
||
|
||
def _parse_schema_abstract(s): | ||
|
@@ -898,11 +927,11 @@ def _parse_schema_abstract(s): | |
>>> _parse_schema_abstract("c{} d{a b}") | ||
StructType...c,MapType...d,MapType...a...b... | ||
>>> _parse_schema_abstract("a b(t)").fields[1] | ||
StructField(b,StructType(List(StructField(t,None,true))),true) | ||
StructField(b,StructType(List(StructField(t,NullType,true))),true) | ||
""" | ||
s = s.strip() | ||
if not s: | ||
return | ||
return NullType() | ||
|
||
elif s.startswith('('): | ||
return _parse_schema_abstract(s[1:-1]) | ||
|
@@ -911,7 +940,7 @@ def _parse_schema_abstract(s): | |
return ArrayType(_parse_schema_abstract(s[1:-1]), True) | ||
|
||
elif s.startswith('{'): | ||
return MapType(None, _parse_schema_abstract(s[1:-1])) | ||
return MapType(NullType(), _parse_schema_abstract(s[1:-1])) | ||
|
||
parts = _split_schema_abstract(s) | ||
fields = [_parse_field_abstract(p) for p in parts] | ||
|
@@ -931,7 +960,7 @@ def _infer_schema_type(obj, dataType): | |
>>> _infer_schema_type(row, schema) | ||
StructType...a,ArrayType...b,MapType(StringType,...c,LongType... | ||
""" | ||
if dataType is None: | ||
if dataType is NullType(): | ||
return _infer_type(obj) | ||
|
||
if not obj: | ||
|
@@ -1037,8 +1066,7 @@ def _verify_type(obj, dataType): | |
for v, f in zip(obj, dataType.fields): | ||
_verify_type(v, f.dataType) | ||
|
||
|
||
_cached_cls = {} | ||
_cached_cls = weakref.WeakValueDictionary() | ||
|
||
|
||
def _restore_object(dataType, obj): | ||
|
@@ -1233,8 +1261,7 @@ def __new__(self, *args, **kwargs): | |
elif kwargs: | ||
# create row objects | ||
names = sorted(kwargs.keys()) | ||
values = tuple(kwargs[n] for n in names) | ||
row = tuple.__new__(self, values) | ||
row = tuple.__new__(self, [kwargs[n] for n in names]) | ||
row.__FIELDS__ = names | ||
return row | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are duplicated, also in types.py.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, good catch.