Skip to content

Commit

Permalink
feat: add some misc CUDA kernels (#3141)
Browse files: browse the repository at this point in the history
* feat: add awkward_NumpyArray_subrange_equal and awkward_NumpyArray_subrange_equal_bool kernels

* fix: grid-stride loop

* fix: awkward_ListOffsetArray_rpad_axis1

* feat: add awkward_UnionArray_regular_index.cu

* test: rearrange and add tests
  • Loading branch information
ManasviGoyal committed Jun 7, 2024
1 parent 5de7b35 commit ef1e851
Show file tree
Hide file tree
Showing 16 changed files with 707 additions and 507 deletions.
3 changes: 3 additions & 0 deletions dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListArray_rpad_axis1",
"awkward_UnionArray_regular_index",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListArray_getitem_next_range_spreadadvanced",
"awkward_ListArray_localindex",
Expand All @@ -84,6 +85,8 @@
"awkward_Content_getitem_next_missing_jagged_getmaskstartstop",
"awkward_index_rpad_and_clip_axis0",
"awkward_index_rpad_and_clip_axis1",
"awkward_NumpyArray_subrange_equal",
"awkward_NumpyArray_subrange_equal_bool",
"awkward_IndexedArray_flatten_nextcarry",
"awkward_IndexedArray_flatten_none2empty",
"awkward_IndexedArray_getitem_nextcarry",
Expand Down
3 changes: 3 additions & 0 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ def gencpuunittests(specdict):
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListArray_rpad_axis1",
"awkward_UnionArray_regular_index",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListArray_getitem_next_range_spreadadvanced",
"awkward_ListArray_localindex",
Expand All @@ -869,6 +870,8 @@ def gencpuunittests(specdict):
"awkward_Content_getitem_next_missing_jagged_getmaskstartstop",
"awkward_index_rpad_and_clip_axis0",
"awkward_index_rpad_and_clip_axis1",
"awkward_NumpyArray_subrange_equal",
"awkward_NumpyArray_subrange_equal_bool",
"awkward_IndexedArray_flatten_nextcarry",
"awkward_IndexedArray_flatten_none2empty",
"awkward_IndexedArray_getitem_nextcarry",
Expand Down
4 changes: 2 additions & 2 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ kernels:
k = k + 1
automatic-tests: true

- name: awkward_IndexedArray_local_preparenext
- name: awkward_IndexedArray_local_preparenext_64
specializations:
- name: awkward_IndexedArray_local_preparenext_64
args:
Expand All @@ -555,7 +555,7 @@ kernels:
- {name: nextlen, type: "Const[int64_t]", dir: in, role: default}
description: null
definition: |
def awkward_IndexedArray_local_preparenext(
def awkward_IndexedArray_local_preparenext_64(
tocarry, starts, parents, parentslength, nextparents, nextlen
):
j = 0
Expand Down
144 changes: 137 additions & 7 deletions kernel-test-data.json
Original file line number Diff line number Diff line change
Expand Up @@ -10677,7 +10677,7 @@
},
{
"name": "awkward_ListOffsetArray_rpad_axis1",
"status": false,
"status": true,
"tests": [
{
"error": false,
Expand Down Expand Up @@ -15608,7 +15608,7 @@
},
{
"name": "awkward_UnionArray_regular_index",
"status": false,
"status": true,
"tests": [
{
"error": false,
Expand Down Expand Up @@ -16614,7 +16614,7 @@
]
},
{
"name": "awkward_IndexedArray_local_preparenext",
"name": "awkward_IndexedArray_local_preparenext_64",
"status": true,
"tests": [
{
Expand All @@ -16639,7 +16639,7 @@
"nextparents": [0, 0, 0, 0, 1, 1, 1],
"parentslength": 11,
"parents": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
"starts": [0, 2, 5]
"starts": [0, 6]
},
"outputs": {
"tocarry": [0, 1, 2, 3, -1, -1, 4, 5, 6, -1, -1]
Expand Down Expand Up @@ -19436,7 +19436,7 @@
},
{
"name": "awkward_ListOffsetArray_reduce_nonlocal_nextshifts_64",
"status": false,
"status": true,
"tests": [
{
"error": false,
Expand Down Expand Up @@ -26387,12 +26387,38 @@
"tmpptr": [0, 2, 2, 3, 5],
"fromstarts": [0, 2, 3, 3],
"fromstops": [2, 3, 3, 5],
"length": 4
"length": 5
},
"outputs": {
"toequal": [0]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 2, 2, 0, 2],
"fromstarts": [0, 2, 3, 3],
"fromstops": [2, 3, 3, 5],
"length": 5
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 0, 0, 0, 0],
"fromstarts": [0, 2, 3, 3],
"fromstops": [2, 3, 3, 5],
"length": 5
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
Expand All @@ -26406,6 +26432,32 @@
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
"fromstarts": [0, 2, 4, 6, 8, 10],
"fromstops": [2, 4, 6, 8, 10, 12],
"length": 6
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2],
"fromstarts": [0, 2, 4, 6, 8, 10],
"fromstops": [2, 4, 6, 8, 10, 12],
"length": 6
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
Expand All @@ -26418,6 +26470,19 @@
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [1, 2, 3, 4, 5, 6],
"fromstarts": [2, 2, 2, 2, 2, 2],
"fromstops": [4, 4, 4, 4, 4, 4],
"length": 6
},
"outputs": {
"toequal": [1]
}
}
]
},
Expand Down Expand Up @@ -26458,12 +26523,38 @@
"tmpptr": [0, 2, 2, 3, 5],
"fromstarts": [0, 2, 3, 3],
"fromstops": [2, 3, 3, 5],
"length": 4
"length": 5
},
"outputs": {
"toequal": [0]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 2, 2, 0, 2],
"fromstarts": [0, 2, 3, 3],
"fromstops": [2, 3, 3, 5],
"length": 5
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 0, 0, 0, 0],
"fromstarts": [0, 2, 3, 3],
"fromstops": [2, 3, 3, 5],
"length": 5
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
Expand All @@ -26477,6 +26568,32 @@
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
"fromstarts": [0, 2, 4, 6, 8, 10],
"fromstops": [2, 4, 6, 8, 10, 12],
"length": 6
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2],
"fromstarts": [0, 2, 4, 6, 8, 10],
"fromstops": [2, 4, 6, 8, 10, 12],
"length": 6
},
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
Expand All @@ -26489,6 +26606,19 @@
"outputs": {
"toequal": [1]
}
},
{
"error": false,
"message": "",
"inputs": {
"tmpptr": [1, 2, 3, 4, 5, 6],
"fromstarts": [2, 2, 2, 2, 2, 2],
"fromstops": [4, 4, 4, 4, 4, 4],
"length": 6
},
"outputs": {
"toequal": [1]
}
}
]
},
Expand Down
14 changes: 8 additions & 6 deletions src/awkward/_connect/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def fetch_template_specializations(kernel_dict):
"awkward_IndexedArray_flatten_none2empty",
"awkward_IndexedArray_getitem_nextcarry",
"awkward_IndexedArray_getitem_nextcarry_outindex",
"awkward_ListArray_getitem_next_range_counts",
"awkward_IndexedArray_index_of_nulls",
"awkward_IndexedArray_ranges_next_64",
"awkward_IndexedArray_ranges_carry_next_64",
Expand All @@ -103,19 +102,18 @@ def fetch_template_specializations(kernel_dict):
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_range",
"awkward_ListArray_getitem_next_range_carrylength",
"awkward_ListArray_getitem_next_range_counts",
"awkward_ListArray_min_range",
"awkward_ListArray_rpad_and_clip_length_axis1",
"awkward_ListArray_rpad_axis1",
"awkward_ListOffsetArray_drop_none_indexes",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
"awkward_UnionArray_regular_index",
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
"awkward_ListOffsetArray_rpad_axis1",
"awkward_ListOffsetArray_rpad_length_axis1",
"awkward_MaskedArray_getitem_next_jagged_project",
"awkward_UnionArray_nestedfill_tags_index",
"awkward_NumpyArray_rearrange_shifted",
"awkward_UnionArray_flatten_length",
"awkward_UnionArray_flatten_combine",
"awkward_UnionArray_project",
"awkward_reduce_count_64",
"awkward_reduce_sum",
"awkward_reduce_sum_int32_bool_64",
Expand All @@ -129,6 +127,10 @@ def fetch_template_specializations(kernel_dict):
"awkward_reduce_min",
"awkward_sorting_ranges",
"awkward_sorting_ranges_length",
"awkward_UnionArray_flatten_length",
"awkward_UnionArray_flatten_combine",
"awkward_UnionArray_nestedfill_tags_index",
"awkward_UnionArray_project",
]
template_specializations = []
import re
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ awkward_IndexedArray_fill(
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id < length) {
C fromval = fromindex[thread_id];
toindex[toindexoffset + thread_id] = fromval < 0 ? -1 : (C)(fromval + base);
toindex[toindexoffset + thread_id] = fromval < 0 ? (C)-1 : (C)(fromval + base);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ awkward_IndexedArray_getitem_nextcarry_outindex_b(
RAISE_ERROR(
INDEXEDARRAY_GETITEM_NEXTCARRY_OUTINDEX_ERRORS::IND_OUT_OF_RANGE)
} else if (j < 0) {
toindex[thread_id] = -1;
toindex[thread_id] = (C)-1;
} else {
tocarry[scan_in_array[thread_id] - 1] = j;
toindex[thread_id] = (C)(scan_in_array[thread_id] - 1);
Expand Down
Loading

0 comments on commit ef1e851

Please sign in to comment.