Merge pull request #2 from johndoe6345789/codex/ensure-function-raises-exception-for-invalid-types

Extend parameter annotation enforcement to lambdas
This commit is contained in:
2025-12-24 16:59:18 +00:00
committed by GitHub

View File

@@ -97,12 +97,18 @@ validate_keywords(asdl_keyword_seq *keywords)
}
static int
validate_args(asdl_arg_seq *args)
validate_args(asdl_arg_seq *args, int require_annotations)
{
assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) {
arg_ty arg = asdl_seq_GET(args, i);
VALIDATE_POSITIONS(arg);
if (require_annotations && arg->annotation == NULL) {
PyErr_Format(PyExc_SyntaxError,
"missing type annotation for argument '%U'",
arg->arg);
return 0;
}
if (arg->annotation && !validate_expr(arg->annotation, Load))
return 0;
}
@@ -125,21 +131,35 @@ expr_context_name(expr_context_ty ctx)
}
static int
validate_arguments(arguments_ty args)
validate_vararg(arg_ty arg, int require_annotations)
{
if (require_annotations && arg->annotation == NULL) {
PyErr_Format(PyExc_SyntaxError,
"missing type annotation for argument '%U'",
arg->arg);
return 0;
}
if (arg->annotation && !validate_expr(arg->annotation, Load)) {
return 0;
}
return 1;
}
static int
validate_arguments(arguments_ty args, int require_annotations)
{
assert(!PyErr_Occurred());
if (!validate_args(args->posonlyargs) || !validate_args(args->args)) {
if (!validate_args(args->posonlyargs, require_annotations)
|| !validate_args(args->args, require_annotations)) {
return 0;
}
if (args->vararg && args->vararg->annotation
&& !validate_expr(args->vararg->annotation, Load)) {
return 0;
}
if (!validate_args(args->kwonlyargs))
if (args->vararg && !validate_vararg(args->vararg, require_annotations)) {
return 0;
}
if (!validate_args(args->kwonlyargs, require_annotations))
return 0;
if (args->kwarg && !validate_vararg(args->kwarg, require_annotations)) {
return 0;
if (args->kwarg && args->kwarg->annotation
&& !validate_expr(args->kwarg->annotation, Load)) {
return 0;
}
if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
@@ -272,7 +292,7 @@ validate_expr(expr_ty exp, expr_context_ty ctx)
ret = validate_expr(exp->v.UnaryOp.operand, Load);
break;
case Lambda_kind:
ret = validate_arguments(exp->v.Lambda.args) &&
ret = validate_arguments(exp->v.Lambda.args, 1) &&
validate_expr(exp->v.Lambda.body, Load);
break;
case IfExp_kind:
@@ -734,7 +754,7 @@ validate_stmt(stmt_ty stmt)
case FunctionDef_kind:
ret = validate_body(stmt->v.FunctionDef.body, "FunctionDef") &&
validate_type_params(stmt->v.FunctionDef.type_params) &&
validate_arguments(stmt->v.FunctionDef.args) &&
validate_arguments(stmt->v.FunctionDef.args, 1) &&
validate_exprs(stmt->v.FunctionDef.decorator_list, Load, 0) &&
(!stmt->v.FunctionDef.returns ||
validate_expr(stmt->v.FunctionDef.returns, Load));
@@ -930,7 +950,7 @@ validate_stmt(stmt_ty stmt)
case AsyncFunctionDef_kind:
ret = validate_body(stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
validate_type_params(stmt->v.AsyncFunctionDef.type_params) &&
validate_arguments(stmt->v.AsyncFunctionDef.args) &&
validate_arguments(stmt->v.AsyncFunctionDef.args, 1) &&
validate_exprs(stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
(!stmt->v.AsyncFunctionDef.returns ||
validate_expr(stmt->v.AsyncFunctionDef.returns, Load));