Skip to content

Commit

Permalink
feat: add variable length loop kernels (#3003)
Browse files (browse the repository at this point in the history)
* feat: add variable length kernels

* fix: spec kernel errors

* feat: add awkward_ListArray_broadcast_tooffsets

* fix: awkward_ListArray_compact_offsets kernel

* test: remove XFAIL

* style: pre-commit fixes

* feat: add awkward_ListArray_getitem_jagged_descend.cu

* feat: add awkward_ListArray_getitem_jagged_numvalid

* feat: add awkward_ListArray_getitem_next_range_spreadadvanced

* feat: add awkward_ListOffsetArray_rpad_length_axis1

* feat: add awkward_ListOffsetArray_toRegularArray.cpp

* feat: add awkward_ListArray_localindex

* feat: add awkward_ListOffsetArray_reduce_local_nextparents_64.cu

* feat: add awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64.cu

* feat: add awkward_UnionArray_regular_index_getsize

* refactor: remove _a from the name of the kernel

* test: generate tests when outarg is also an inarg

* feat: add awkward_ListOffsetArray_drop_none_indexes

* fix: awkward_ListOffsetArray_drop_none_indexes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ManasviGoyal and pre-commit-ci[bot] committed Feb 9, 2024
1 parent 2c780d0 commit b7eb8b9
Show file tree
Hide file tree
Showing 23 changed files with 1,805 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ERROR awkward_ListOffsetArray_toRegularArray(
*size = count;
}
else if (*size != count) {
return failure("cannot convert to RegularArray because subarray lengths are not " "regular", i, kSliceNone, FILENAME(__LINE__));
return failure("cannot convert to RegularArray because subarray lengths are not regular", i, kSliceNone, FILENAME(__LINE__));
}
}
if (*size == -1) {
Expand Down
15 changes: 14 additions & 1 deletion dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"awkward_ListArray_min_range",
"awkward_ListArray_validity",
"awkward_BitMaskedArray_to_ByteMaskedArray",
"awkward_ListArray_broadcast_tooffsets",
"awkward_ListArray_compact_offsets",
"awkward_ListOffsetArray_flatten_offsets",
"awkward_IndexedArray_overlay_mask",
Expand Down Expand Up @@ -52,14 +53,18 @@
"awkward_RegularArray_reduce_nonlocal_preparenext",
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListArray_getitem_next_range_spreadadvanced",
"awkward_ListArray_localindex",
"awkward_NumpyArray_reduce_adjust_starts_64",
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
"awkward_RegularArray_getitem_next_at",
Expand All @@ -76,14 +81,22 @@
"awkward_IndexedArray_getitem_nextcarry",
"awkward_IndexedArray_getitem_nextcarry_outindex",
"awkward_IndexedArray_index_of_nulls",
"awkward_IndexedArray_ranges_next_64",
"awkward_IndexedArray_ranges_carry_next_64",
"awkward_IndexedArray_reduce_next_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_fromshifts_64",
"awkward_IndexedOptionArray_rpad_and_clip_mask_axis1",
"awkward_ListOffsetArray_rpad_and_clip_axis1",
"awkward_ListOffsetArray_rpad_length_axis1",
"awkward_ListOffsetArray_toRegularArray",
# "awkward_ListOffsetArray_rpad_axis1",
"awkward_MaskedArray_getitem_next_jagged_project",
"awkward_UnionArray_project",
"awkward_ListOffsetArray_drop_none_indexes",
"awkward_ListOffsetArray_reduce_local_nextparents_64",
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
"awkward_UnionArray_regular_index_getsize",
"awkward_UnionArray_simplify",
"awkward_UnionArray_simplify_one",
"awkward_reduce_argmax",
Expand Down
138 changes: 99 additions & 39 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,10 @@ def getdtypes(args):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
dtypes.append("cupy." + typename)
elif count == 2:
Expand Down Expand Up @@ -286,8 +288,7 @@ def unittestmap():

def getunittests(test_inputs, test_outputs):
unit_tests = {**test_outputs, **test_inputs}
num_outputs = len(test_outputs)
return unit_tests, num_outputs
return unit_tests


def gettypename(spectype):
Expand Down Expand Up @@ -602,32 +603,52 @@ def gencpuunittests(specdict):
funcName = (
"def test_unit_cpu" + spec.name + "_" + str(num) + "():\n"
)
unit_tests, num_outputs = getunittests(
test["inputs"], test["outputs"]
)
unit_tests = getunittests(test["inputs"], test["outputs"])
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
num += 1
f.write(funcName)
for i, (arg, val) in enumerate(unit_tests.items()):
for arg, val in test["outputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)
if i < num_outputs:
f.write(
" " * 4
+ arg
+ " = "
+ str([gettypeval(typename)] * len(val))
+ "\n"
)
else:
f.write(" " * 4 + arg + " = " + str(val) + "\n")
f.write(
" " * 4
+ arg
+ " = "
+ str([gettypeval(typename)] * len(val))
+ "\n"
)
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if count == 1:
f.write(
" " * 4
+ f"{arg} = (ctypes.c_{typename}*len({arg}))(*{arg})\n"
)
elif count == 2:
f.write(
" " * 4
+ "{0} = ctypes.pointer(ctypes.cast((ctypes.c_{1}*len({0}[0]))(*{0}[0]),ctypes.POINTER(ctypes.c_{1})))\n".format(
arg, typename
)
)
for arg, val in test["inputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)

f.write(" " * 4 + arg + " = " + str(val) + "\n")
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
Expand Down Expand Up @@ -680,6 +701,7 @@ def gencpuunittests(specdict):
"awkward_ListArray_min_range",
"awkward_ListArray_validity",
"awkward_BitMaskedArray_to_ByteMaskedArray",
"awkward_ListArray_broadcast_tooffsets",
"awkward_ListArray_compact_offsets",
"awkward_ListOffsetArray_flatten_offsets",
"awkward_IndexedArray_overlay_mask",
Expand Down Expand Up @@ -716,14 +738,18 @@ def gencpuunittests(specdict):
"awkward_RegularArray_reduce_nonlocal_preparenext",
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListArray_getitem_next_range_spreadadvanced",
"awkward_ListArray_localindex",
"awkward_NumpyArray_reduce_adjust_starts_64",
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
"awkward_RegularArray_getitem_next_at",
Expand All @@ -740,14 +766,22 @@ def gencpuunittests(specdict):
"awkward_IndexedArray_getitem_nextcarry",
"awkward_IndexedArray_getitem_nextcarry_outindex",
"awkward_IndexedArray_index_of_nulls",
"awkward_IndexedArray_ranges_next_64",
"awkward_IndexedArray_ranges_carry_next_64",
"awkward_IndexedArray_reduce_next_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64",
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_fromshifts_64",
"awkward_IndexedOptionArray_rpad_and_clip_mask_axis1",
"awkward_ListOffsetArray_rpad_and_clip_axis1",
"awkward_ListOffsetArray_rpad_length_axis1",
"awkward_ListOffsetArray_toRegularArray",
# "awkward_ListOffsetArray_rpad_axis1",
"awkward_MaskedArray_getitem_next_jagged_project",
"awkward_UnionArray_project",
"awkward_ListOffsetArray_drop_none_indexes",
"awkward_ListOffsetArray_reduce_local_nextparents_64",
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
"awkward_UnionArray_regular_index_getsize",
"awkward_UnionArray_simplify",
"awkward_UnionArray_simplify_one",
"awkward_reduce_argmax",
Expand Down Expand Up @@ -841,8 +875,10 @@ def gencudakerneltests(specdict):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
f.write(
" " * 4
Expand Down Expand Up @@ -957,9 +993,7 @@ def gencudaunittests(specdict):
"def test_unit_cuda" + spec.name + "_" + str(num) + "():\n"
)
dtypes = getdtypes(spec.args)
unit_tests, num_outputs = getunittests(
test["inputs"], test["outputs"]
)
unit_tests = getunittests(test["inputs"], test["outputs"])
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
Expand All @@ -969,7 +1003,7 @@ def gencudaunittests(specdict):
"@pytest.mark.skip(reason='Kernel is not implemented properly')\n"
)
f.write(funcName)
for i, (arg, val) in enumerate(unit_tests.items()):
for arg, val in test["outputs"].items():
typename = remove_const(
next(
argument
Expand All @@ -982,25 +1016,50 @@ def gencudaunittests(specdict):
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool" or typename == "float":
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
if i < num_outputs:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg,
[gettypeval(typename)] * len(val),
typename,
)
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg,
[gettypeval(typename)] * len(val),
typename,
)
else:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, val, typename
)
)
elif count == 2:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, val, typename
)
)
for arg, val in test["inputs"].items():
typename = remove_const(
next(
argument
for argument in spec.args
if argument.name == arg
).typename
)
if "List" not in typename:
f.write(" " * 4 + arg + " = " + str(val) + "\n")
if "List" in typename:
count = typename.count("List")
typename = gettypename(typename)
if typename == "bool":
typename = typename + "_"
if typename == "float":
typename = typename + "32"
if count == 1:
f.write(
" " * 4
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
arg, val, typename
)
)
elif count == 2:
f.write(
" " * 4
Expand Down Expand Up @@ -1083,7 +1142,8 @@ def genunittests():
for key in test["outputs"]:
line += key + " = " + key + ","
for key in test["inputs"]:
line += key + " = " + key + ","
if key not in test["outputs"]:
line += key + " = " + key + ","
line = line[0 : len(line) - 1]
line += ")\n"
if test["error"]:
Expand Down
9 changes: 5 additions & 4 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1001,10 +1001,10 @@ kernels:
stride = fromstops[i] - fromstarts[i]
tostarts[i] = k
for j in range(stride):
if index[fromstarts[i] + j] > 0:
if index[fromstarts[i] + j] >= 0:
k = k + 1
tostops[i] = k
tolength = k
tolength[0] = k
automatic-tests: false


Expand Down Expand Up @@ -1040,7 +1040,7 @@ kernels:
for i in range(length):
stride = fromstops[i] - fromstarts[i]
for j in range(stride):
if index[fromstarts[i] + j] > 0:
if index[fromstarts[i] + j] >= 0:
tocarry[k] = index[fromstarts[i] + j]
k = k + 1
automatic-tests: false
Expand Down Expand Up @@ -1515,7 +1515,8 @@ kernels:
if slicestop > missinglength:
raise ValueError("jagged slice's offsets extend beyond its content")
for j in range(slicestart, slicestop):
numvalid[0] = numvalid[0] + 1 if missing[j] >= 0 else 0
if missing[j] >= 0:
numvalid[0] = numvalid[0] + 1
automatic-tests: false

- name: awkward_ListArray_getitem_jagged_shrink
Expand Down
Loading

0 comments on commit b7eb8b9

Please sign in to comment.