Skip to content

Commit

Permalink
fix: generate error message and tests for CUDA and CPU kernels (#2989)
Browse files Browse the repository at this point in the history
* fix: generate error messages for CUDA kernels

* tests: add tests for checking error messages

* fix: test for awkward_ListArray_getitem_next_array_advanced

* test: error messge for CPU kernels

* fix: error message in CUDA awkward_ListArray_compact_offsets to match with the CPU one

* fix: error in pytest.raise

* test: add re.escape() in generated tests

---------

Co-authored-by: Jim Pivarski <jpivarski@users.noreply.github.com>
  • Loading branch information
ManasviGoyal and jpivarski committed Jan 30, 2024
1 parent 9191df2 commit f2a2340
Show file tree
Hide file tree
Showing 5 changed files with 2,017 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ERROR awkward_ListArray_rpad_axis1(
}
offset = (target > rangeval) ? tostarts[i] + target : tostarts[i] + rangeval;
tostops[i] = offset;
}
}
return success();
}
ERROR awkward_ListArray32_rpad_axis1_64(
Expand Down
96 changes: 60 additions & 36 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,23 @@ def checkuint(test_args, args):
return flag


def checkintrange(test_args, args):
def checkintrange(test_args, error, args):
flag = True
for arg, val in test_args:
typename = remove_const(
next(argument for argument in args if argument.name == arg).typename
)
if "int" in typename or "uint" in typename:
dtype = gettypename(typename)
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
if "List" in typename:
for data in val:
if not (min_val <= data <= max_val):
if not error:
for arg, val in test_args:
typename = remove_const(
next(argument for argument in args if argument.name == arg).typename
)
if "int" in typename or "uint" in typename:
dtype = gettypename(typename)
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
if "List" in typename:
for data in val:
if not (min_val <= data <= max_val):
flag = False
else:
if not (min_val <= val <= max_val):
flag = False
else:
if not (min_val <= val <= max_val):
flag = False
return flag


Expand Down Expand Up @@ -581,7 +582,7 @@ def gencpuunittests(specdict):
test["inputs"], test["outputs"]
)
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
num += 1
f.write(funcName)
Expand Down Expand Up @@ -628,17 +629,25 @@ def gencpuunittests(specdict):
count += 1
else:
args += ", " + arg.name
f.write(" " * 4 + "ret_pass = funcC(" + args + ")\n")
for arg, val in test["outputs"].items():
f.write(" " * 4 + "pytest_" + arg + " = " + str(val) + "\n")
if isinstance(val, list):
if not test["error"]:
f.write(" " * 4 + "ret_pass = funcC(" + args + ")\n")
for arg, val in test["outputs"].items():
f.write(
" " * 4
+ f"assert {arg}[:len(pytest_{arg})] == pytest.approx(pytest_{arg})\n"
" " * 4 + "pytest_" + arg + " = " + str(val) + "\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write(" " * 4 + "assert not ret_pass.str\n")
if isinstance(val, list):
f.write(
" " * 4
+ f"assert {arg}[:len(pytest_{arg})] == pytest.approx(pytest_{arg})\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write(" " * 4 + "assert not ret_pass.str\n")
else:
f.write(
" " * 4
+ f"assert funcC({args}).str.decode('utf-8') == \"{test['message']}\"\n"
)
f.write("\n")


Expand Down Expand Up @@ -896,6 +905,7 @@ def gencudaunittests(specdict):
)

f.write(
"import re\n"
"import cupy\n"
"import pytest\n\n"
"import awkward as ak\n"
Expand All @@ -915,7 +925,7 @@ def gencudaunittests(specdict):
test["inputs"], test["outputs"]
)
flag = checkuint(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), spec.args)
range = checkintrange(unit_tests.items(), test["error"], spec.args)
if flag and range:
num += 1
if not status:
Expand Down Expand Up @@ -973,24 +983,38 @@ def gencudaunittests(specdict):
else:
args += ", " + arg.name
f.write(" " * 4 + "funcC(" + args + ")\n")
f.write(
"""
if test["error"]:
f.write(
f"""
error_message = re.escape("{test['message']} in compiled CUDA code ({spec.templatized_kernel_name})")
"""
)
f.write(
""" with pytest.raises(ValueError, match=rf"{error_message}"):
ak_cu.synchronize_cuda()
"""
)
else:
f.write(
"""
try:
ak_cu.synchronize_cuda()
except:
pytest.fail("This test case shouldn't have raised an error")
"""
)
for arg, val in test["outputs"].items():
f.write(" " * 4 + "pytest_" + arg + " = " + str(val) + "\n")
if isinstance(val, list):
)
for arg, val in test["outputs"].items():
f.write(
" " * 4
+ f"assert cupy.array_equal({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
" " * 4 + "pytest_" + arg + " = " + str(val) + "\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write("\n")
if isinstance(val, list):
f.write(
" " * 4
+ f"assert cupy.array_equal({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
)
else:
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
f.write("\n")


def genunittests():
Expand Down
Loading

0 comments on commit f2a2340

Please sign in to comment.