Merge pull request #2 from johndoe6345789/codex/ensure-function-raises-exception-for-invalid-types

Extend parameter annotation enforcement to lambdas
This commit is contained in:
2025-12-24 16:59:18 +00:00
committed by GitHub

View File

@@ -97,12 +97,18 @@ validate_keywords(asdl_keyword_seq *keywords)
} }
static int static int
validate_args(asdl_arg_seq *args) validate_args(asdl_arg_seq *args, int require_annotations)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) {
arg_ty arg = asdl_seq_GET(args, i); arg_ty arg = asdl_seq_GET(args, i);
VALIDATE_POSITIONS(arg); VALIDATE_POSITIONS(arg);
if (require_annotations && arg->annotation == NULL) {
PyErr_Format(PyExc_SyntaxError,
"missing type annotation for argument '%U'",
arg->arg);
return 0;
}
if (arg->annotation && !validate_expr(arg->annotation, Load)) if (arg->annotation && !validate_expr(arg->annotation, Load))
return 0; return 0;
} }
@@ -125,21 +131,35 @@ expr_context_name(expr_context_ty ctx)
} }
static int static int
validate_arguments(arguments_ty args) validate_vararg(arg_ty arg, int require_annotations)
{
if (require_annotations && arg->annotation == NULL) {
PyErr_Format(PyExc_SyntaxError,
"missing type annotation for argument '%U'",
arg->arg);
return 0;
}
if (arg->annotation && !validate_expr(arg->annotation, Load)) {
return 0;
}
return 1;
}
static int
validate_arguments(arguments_ty args, int require_annotations)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
if (!validate_args(args->posonlyargs) || !validate_args(args->args)) { if (!validate_args(args->posonlyargs, require_annotations)
|| !validate_args(args->args, require_annotations)) {
return 0; return 0;
} }
if (args->vararg && args->vararg->annotation if (args->vararg && !validate_vararg(args->vararg, require_annotations)) {
&& !validate_expr(args->vararg->annotation, Load)) { return 0;
return 0; }
} if (!validate_args(args->kwonlyargs, require_annotations))
if (!validate_args(args->kwonlyargs)) return 0;
if (args->kwarg && !validate_vararg(args->kwarg, require_annotations)) {
return 0; return 0;
if (args->kwarg && args->kwarg->annotation
&& !validate_expr(args->kwarg->annotation, Load)) {
return 0;
} }
if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) { if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments"); PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
@@ -272,7 +292,7 @@ validate_expr(expr_ty exp, expr_context_ty ctx)
ret = validate_expr(exp->v.UnaryOp.operand, Load); ret = validate_expr(exp->v.UnaryOp.operand, Load);
break; break;
case Lambda_kind: case Lambda_kind:
ret = validate_arguments(exp->v.Lambda.args) && ret = validate_arguments(exp->v.Lambda.args, 1) &&
validate_expr(exp->v.Lambda.body, Load); validate_expr(exp->v.Lambda.body, Load);
break; break;
case IfExp_kind: case IfExp_kind:
@@ -734,7 +754,7 @@ validate_stmt(stmt_ty stmt)
case FunctionDef_kind: case FunctionDef_kind:
ret = validate_body(stmt->v.FunctionDef.body, "FunctionDef") && ret = validate_body(stmt->v.FunctionDef.body, "FunctionDef") &&
validate_type_params(stmt->v.FunctionDef.type_params) && validate_type_params(stmt->v.FunctionDef.type_params) &&
validate_arguments(stmt->v.FunctionDef.args) && validate_arguments(stmt->v.FunctionDef.args, 1) &&
validate_exprs(stmt->v.FunctionDef.decorator_list, Load, 0) && validate_exprs(stmt->v.FunctionDef.decorator_list, Load, 0) &&
(!stmt->v.FunctionDef.returns || (!stmt->v.FunctionDef.returns ||
validate_expr(stmt->v.FunctionDef.returns, Load)); validate_expr(stmt->v.FunctionDef.returns, Load));
@@ -930,7 +950,7 @@ validate_stmt(stmt_ty stmt)
case AsyncFunctionDef_kind: case AsyncFunctionDef_kind:
ret = validate_body(stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") && ret = validate_body(stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
validate_type_params(stmt->v.AsyncFunctionDef.type_params) && validate_type_params(stmt->v.AsyncFunctionDef.type_params) &&
validate_arguments(stmt->v.AsyncFunctionDef.args) && validate_arguments(stmt->v.AsyncFunctionDef.args, 1) &&
validate_exprs(stmt->v.AsyncFunctionDef.decorator_list, Load, 0) && validate_exprs(stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
(!stmt->v.AsyncFunctionDef.returns || (!stmt->v.AsyncFunctionDef.returns ||
validate_expr(stmt->v.AsyncFunctionDef.returns, Load)); validate_expr(stmt->v.AsyncFunctionDef.returns, Load));