Skip to content

Commit

Permalink
Type inference for lists and lists-of-lists
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed Mar 7, 2016
1 parent db34836 commit fba2ab8
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 34 deletions.
1 change: 1 addition & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ set(CYTHON_EXTENSIONS
config
error
parquet
scalar
schema
)

Expand Down
10 changes: 9 additions & 1 deletion python/arrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@

# flake8: noqa

from arrow.array import Array, from_list, total_allocated_bytes
from arrow.array import (Array, from_pylist, total_allocated_bytes,
BooleanArray, NumericArray,
Int8Array, UInt8Array,
ListArray, StringArray)

from arrow.error import ArrowException

from arrow.scalar import ArrayValue, NA, Scalar

from arrow.schema import (null, bool_,
int8, int16, int32, int64,
uint8, uint16, uint32, uint64,
Expand Down
7 changes: 7 additions & 0 deletions python/arrow/array.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,23 @@
from arrow.includes.common cimport shared_ptr
from arrow.includes.arrow cimport CArray, LogicalType

from arrow.scalar import NA

from arrow.schema cimport DataType

cdef extern from "Python.h":
int PySlice_Check(object)

cdef class Array:
cdef:
shared_ptr[CArray] sp_array
CArray* ap

cdef readonly:
DataType type

cdef init(self, const shared_ptr[CArray]& sp_array)
cdef _getitem(self, int i)


cdef class BooleanArray(Array):
Expand Down
40 changes: 39 additions & 1 deletion python/arrow/array.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ cimport arrow.includes.pyarrow as pyarrow
from arrow.compat import frombytes, tobytes
from arrow.error cimport check_status

from arrow.scalar import NA

def total_allocated_bytes():
cdef MemoryPool* pool = pyarrow.GetMemoryPool()
Expand All @@ -35,6 +36,7 @@ cdef class Array:

cdef init(self, const shared_ptr[CArray]& sp_array):
self.sp_array = sp_array
self.ap = sp_array.get()
self.type = DataType()
self.type.init(self.sp_array.get().type())

Expand All @@ -46,6 +48,42 @@ cdef class Array:
def __len__(self):
return self.sp_array.get().length()

def isnull(self):
raise NotImplemented

def __getitem__(self, key):
cdef:
Py_ssize_t n = len(self)

if PySlice_Check(key):
start = key.start or 0
while start < 0:
start += n

stop = key.stop if key.stop is not None else n
while stop < 0:
stop += n

step = key.step or 1
if step != 1:
raise NotImplementedError
else:
return self.slice(start, stop)

while key < 0:
key += len(self)

if self.ap.IsNull(key):
return NA
else:
return self._getitem(key)

cdef _getitem(self, int i):
raise NotImplementedError

def slice(self, start, end):
pass


cdef class NullArray(Array):
pass
Expand Down Expand Up @@ -121,7 +159,7 @@ cdef object box_arrow_array(const shared_ptr[CArray]& sp_array):
return arr


def from_list(object list_obj, type=None):
def from_pylist(object list_obj, type=None):
"""
Convert Python list to Arrow array
"""
Expand Down
4 changes: 3 additions & 1 deletion python/arrow/error.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from arrow.includes.common cimport c_string

from arrow.compat import frombytes

class ArrowException(Exception):
pass

Expand All @@ -25,4 +27,4 @@ cdef check_status(const Status& status):
return

cdef c_string c_message = status.ToString()
return ArrowException(c_message)
raise ArrowException(frombytes(c_message))
18 changes: 13 additions & 5 deletions python/arrow/scalar.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@ from arrow.includes.arrow cimport CArray, CListArray

from arrow.schema cimport DataType

cdef class ScalarValue:
cdef class Scalar:
cdef readonly:
DataType type


cdef class NAType(Scalar):
pass


cdef class ArrayValue(Scalar):
cdef:
shared_ptr[CArray] array
int index
DataType type


cdef class Int8Value:
cdef class Int8Value(ArrayValue):
pass


cdef class ListValue:
cdef class ListValue(ArrayValue):
pass


cdef class StringValue:
cdef class StringValue(ArrayValue):
pass
28 changes: 28 additions & 0 deletions python/arrow/scalar.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import arrow.schema as schema

cdef class NAType(Scalar):

def __cinit__(self):
self.type = schema.null()

def __repr__(self):
return 'NA'

NA = NAType()
26 changes: 26 additions & 0 deletions python/arrow/tests/test_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from arrow.compat import unittest
import arrow


class TestArrayAPI(unittest.TestCase):

def test_getitem_NA(self):
arr = arrow.from_pylist([1, None, 2])
assert arr[1] is arrow.NA
28 changes: 23 additions & 5 deletions python/arrow/tests/test_convert_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,27 @@ def test_boolean(self):
pass

def test_empty_list(self):
arr = arrow.from_list([])
arr = arrow.from_pylist([])
assert len(arr) == 0
assert arr.null_count == 0
assert arr.type == arrow.null()

def test_all_none(self):
arr = arrow.from_list([None, None])
arr = arrow.from_pylist([None, None])
assert len(arr) == 2
assert arr.null_count == 2
assert arr.type == arrow.null()

def test_integer(self):
arr = arrow.from_list([1, None, 3, None])
arr = arrow.from_pylist([1, None, 3, None])
assert len(arr) == 4
assert arr.null_count == 2
assert arr.type == arrow.int64()

def test_garbage_collection(self):
import gc
bytes_before = arrow.total_allocated_bytes()
arrow.from_list([1, None, 3, None])
arrow.from_pylist([1, None, 3, None])
gc.collect()
assert arrow.total_allocated_bytes() == bytes_before

Expand All @@ -56,4 +56,22 @@ def test_string(self):
pass

def test_list_of_int(self):
pass
data = [[1, 2, 3], [], None, [1, 2]]
arr = arrow.from_pylist(data)
# assert len(arr) == 4
# assert arr.null_count == 1
assert arr.type == arrow.list_(arrow.int64())

def test_mixed_nesting_levels(self):
arrow.from_pylist([1, 2, None])
arrow.from_pylist([[1], [2], None])
arrow.from_pylist([[1], [2], [None]])

with self.assertRaises(arrow.ArrowException):
arrow.from_pylist([1, 2, [1]])

with self.assertRaises(arrow.ArrowException):
arrow.from_pylist([1, 2, []])

with self.assertRaises(arrow.ArrowException):
arrow.from_pylist([[1], [2], [None, [1]]])
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def get_ext_built(self, name):
return name + suffix

def get_cmake_cython_names(self):
return ['array', 'config', 'error', 'parquet', 'schema']
return ['array', 'config', 'error', 'parquet', 'scalar', 'schema']

def get_names(self):
return self._found_names
Expand Down
Loading

0 comments on commit fba2ab8

Please sign in to comment.