# Copyright 2025 Karlsruhe Institute of Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from importlib import metadata
from flask import current_app
from marshmallow import fields
import kadi.lib.constants as const
from kadi.lib.schemas import CustomString
from kadi.lib.web import parse_url_rule
# Matches a dotted reference that starts with "kadi.modules.<name>" and runs up to
# the next whitespace character; group 1 captures the lowercase module name. It is
# applied to the "__module__" of view functions (to derive the tag of an operation)
# and to description texts (to turn module references into documentation links).
MODULE_REGEX = re.compile(r"kadi\.modules\.([a-z]+)[^\s]+")
# Renames captured module names before they are used as operation tags. Module names
# that are not listed here are used unchanged.
MODULE_MAP = {
"accounts": "users",
"main": "misc",
}
# Maps URL rule converter names to the schema type used for path parameters.
# Converters that are not listed here are documented as "string".
PATH_TYPE_MAP = {
"int": "integer",
}
# Maps schema field classes to the type used in request body schemas. The lookup is
# done using the exact class of a field, so subclasses of the listed classes do not
# match; fields that are not listed here are documented as "object".
SCHEMA_TYPE_MAP = {
CustomString: "string",
fields.String: "string",
fields.Integer: "integer",
fields.Boolean: "boolean",
}
# (stray "[docs]" link marker from the rendered documentation removed)
class OpenAPISpec:
    """Container class for generating OpenAPI specification of Kadi4Mat's HTTP API.

    Uses OpenAPI specification version ``3.1.1``.

    :param version: The version of the Kadi4Mat HTTP API to generate the specification
        for. If the given version is invalid, the default API version as defined in
        :const:`kadi.lib.constants.API_VERSION_DEFAULT` is used.
    :param app: (optional) The application to generate the specification for. Defaults
        to the current application.
    """

    def __init__(self, version, app=None):
        self.version = (
            version if version in const.API_VERSIONS else const.API_VERSION_DEFAULT
        )
        self.app = app if app is not None else current_app

        kadi_version = metadata.version("kadi")

        self._routes = []
        self._spec = {
            "openapi": "3.1.1",
            "info": {
                "title": f"Kadi4Mat HTTP API {self.version}",
                "summary": (
                    f"The HTTP API {self.version} of the virtual research environment"
                    " Kadi4Mat."
                ),
                "description": (
                    f"This specification was generated using Kadi4Mat version"
                    f" `{kadi_version}` and documents all endpoints corresponding to"
                    f" Kadi4Mat's HTTP API version `{self.version}`.\n\nFor more"
                    " information about Kadi4Mat, please see its"
                    f" [website]({const.URL_INDEX})."
                ),
                "version": kadi_version,
            },
            "servers": [
                {"url": self.app.base_url},
            ],
            "paths": {},
            "components": {
                "securitySchemes": {
                    "bearerAuth": {
                        "type": "http",
                        "description": "A personal access token or an OAuth2 token.",
                        "scheme": "bearer",
                    },
                },
            },
        }

        # The specification is built eagerly, so it is complete once the constructor
        # returns.
        self._collect_routes()
        self._populate_paths()

    @property
    def spec(self):
        """Get the OpenAPI specification as a dictionary."""
        return self._spec

    def _collect_routes(self):
        """Collect all routes of the application that should be documented.

        Each collected route is a dictionary containing the endpoint name, the view
        function, the lowercase HTTP method, the (mapped) module name, the path
        parameters (name to converter) and the path in OpenAPI notation. The routes are
        stored in ``self._routes``, sorted by path.
        """
        version_suffix = f"_{self.version}"

        for rule in self.app.url_map.iter_rules():
            endpoint = rule.endpoint

            # Exclude non-API endpoints and endpoints that don't match the given API
            # version. This is checked first so view functions are only looked up for
            # endpoints that are actually relevant.
            if not endpoint.startswith("api.") or not endpoint.endswith(version_suffix):
                continue

            view_func = self.app.view_functions[endpoint]

            # Exclude internal and experimental endpoints.
            apispec_meta = getattr(view_func, const.APISPEC_META_ATTR, {})
            is_internal = apispec_meta.get(const.APISPEC_INTERNAL_KEY, False)
            is_experimental = apispec_meta.get(const.APISPEC_EXPERIMENTAL_KEY, False)

            if is_internal or is_experimental:
                continue

            # Exclude endpoints that are not part of a Kadi4Mat "module".
            match = MODULE_REGEX.search(view_func.__module__)

            if not match:
                continue

            module = match.group(1)
            module = MODULE_MAP.get(module, module)

            # Sets have no stable order, so sort the remaining methods to always pick
            # the same one, and skip rules that only answer to OPTIONS/HEAD instead of
            # failing on them.
            methods = sorted(set(rule.methods or ()).difference({"OPTIONS", "HEAD"}))

            if not methods:
                continue

            path_parts = []
            parameters = {}

            for conv, _, var in parse_url_rule(rule.rule):
                if conv:
                    # Variable parts use the "{name}" notation of OpenAPI.
                    path_parts.append(f"{{{var}}}")
                    parameters[var] = conv
                else:
                    path_parts.append(var)

            self._routes.append(
                {
                    "endpoint": endpoint,
                    "func": view_func,
                    "method": methods[0].lower(),
                    "module": module,
                    "parameters": parameters,
                    "path": "".join(path_parts),
                }
            )

        self._routes.sort(key=lambda route: route["path"])

    def _populate_paths(self):
        """Add an operation for each collected route to the specification."""
        paths = self.spec["paths"]

        for route in self._routes:
            view_func = route["func"]
            summary, description = self._extract_docstring(view_func)

            operation_spec = {
                "tags": [route["module"]],
                "summary": summary,
                "operationId": route["endpoint"],
                "responses": {},
            }

            if description:
                operation_spec["description"] = description

            apispec_meta = getattr(view_func, const.APISPEC_META_ATTR, {})

            self._add_security(operation_spec, apispec_meta)
            self._add_responses(operation_spec, apispec_meta)
            self._add_parameters(operation_spec, apispec_meta, route)
            self._add_request_body(operation_spec, apispec_meta, route)

            paths.setdefault(route["path"], {})[route["method"]] = operation_spec

    @staticmethod
    def _join_docstring_lines(text):
        """Join the non-empty, stripped lines of a docstring section with spaces."""
        stripped_lines = (line.strip() for line in text.splitlines())
        return " ".join(line for line in stripped_lines if line)

    def _extract_docstring(self, view_func):
        """Extract the summary and description from the docstring of a view function.

        The first section of the docstring (everything up to the first blank line) is
        used as summary, all following sections make up the description.

        :param view_func: The view function to extract the docstring from.
        :return: A tuple containing the summary and the prepared description. Both are
            empty strings if the view function has no docstring.
        """
        # View functions without a docstring have "__doc__" set to "None".
        sections = (view_func.__doc__ or "").split("\n\n")

        summary = self._join_docstring_lines(sections[0])
        paragraphs = [self._join_docstring_lines(section) for section in sections[1:]]
        description = "\n\n".join(paragraph for paragraph in paragraphs if paragraph)

        if not description:
            # Nothing to prepare, as preparing an empty description is a no-op.
            return summary, ""

        return summary, self._prepare_description(description)

    def _prepare_description(self, description):
        """Prepare a description text for use in the specification.

        :param description: The description text to prepare.
        :return: The description with versioned endpoint paths and with references to
            module functions replaced by documentation hyperlinks.
        """
        # Replace textual references to other endpoints with the versioned endpoint.
        path_prefix = "/api/"
        description = description.replace(path_prefix, f"{path_prefix}{self.version}/")

        # Replace textual references to module functions with a corresponding
        # documentation hyperlink.
        def _replace_reference(match):
            ref = match.group(0)
            return f"[`{ref}()`]({const.URL_RTD_STABLE}/apiref/modules.html#{ref})"

        return MODULE_REGEX.sub(_replace_reference, description)

    def _add_security(self, operation_spec, apispec_meta):
        """Add the security requirement, including any scopes, to an operation."""
        scopes = apispec_meta.get(const.APISPEC_SCOPES_KEY, [])
        operation_spec["security"] = [{"bearerAuth": scopes}]

    def _add_responses(self, operation_spec, apispec_meta):
        """Add the documented status codes of an endpoint to an operation."""
        status_meta = apispec_meta.get(const.APISPEC_STATUS_KEY, {})

        for status_code, description in status_meta.items():
            operation_spec["responses"][status_code] = {
                "description": self._prepare_description(description)
            }

    def _add_parameters(self, operation_spec, apispec_meta, route):
        """Add all path, header and query parameters to an operation.

        The ``"parameters"`` key is only added if there is at least one parameter.

        :param operation_spec: The operation dictionary to update in place.
        :param apispec_meta: The API specification metadata of the view function.
        :param route: The collected route corresponding to the operation.
        """
        # Add path parameters.
        parameters = [
            {
                "name": name,
                "in": "path",
                "required": True,
                "schema": {
                    "type": PATH_TYPE_MAP.get(converter, "string"),
                },
            }
            for name, converter in route["parameters"].items()
        ]

        # Add header parameters.
        reqheaders_meta = apispec_meta.get(const.APISPEC_REQ_HEADERS_KEY, {})

        for name, data in reqheaders_meta.items():
            header_param = {
                "name": name,
                "in": "header",
                "description": self._prepare_description(data.get("description", "")),
                "schema": {
                    "type": data.get("type", "string"),
                },
            }

            if data.get("required", False):
                header_param["required"] = True

            parameters.append(header_param)

        # Add pagination query parameters.
        pagination_meta = apispec_meta.get(const.APISPEC_PAGINATION_KEY, {})

        if pagination_meta:
            page_param = {
                "name": "page",
                "in": "query",
                "description": "The current result page.",
                "schema": {
                    "type": "integer",
                    "default": 1,
                    "minimum": 1,
                },
            }

            page_max = pagination_meta["page_max"]

            if page_max:
                page_param["schema"]["maximum"] = page_max

            per_page_param = {
                "name": "per_page",
                "in": "query",
                "description": "Number of results per page.",
                "schema": {
                    "type": "integer",
                    "default": 10,
                    "minimum": 1,
                    "maximum": pagination_meta["per_page_max"],
                },
            }

            parameters += [page_param, per_page_param]

        # Add all other query parameters.
        qparam_meta = apispec_meta.get(const.APISPEC_QPARAMS_KEY, {})

        for name, data in qparam_meta.items():
            value_schema = {"type": data["type"]}

            # The valid choices restrict each individual value, so they have to be
            # part of the item schema if the parameter can be given multiple times.
            if choice := data["choice"]:
                value_schema["enum"] = choice

            if data["multiple"]:
                schema = {"type": "array", "items": value_schema}
            else:
                schema = value_schema

            default_value = data["default"]

            if default_value is not None and default_value != "":
                # Fall back to the string representation for all values that are not
                # of a simple type.
                if not isinstance(default_value, (str, int, bool)):
                    default_value = str(default_value)

                schema["default"] = default_value

            parameters.append(
                {
                    "name": name,
                    "in": "query",
                    "description": self._prepare_description(data["description"]),
                    "schema": schema,
                }
            )

        if parameters:
            operation_spec["parameters"] = parameters

    def _add_request_body(self, operation_spec, apispec_meta, route):
        """Add the request body to an operation, if applicable.

        :param operation_spec: The operation dictionary to update in place.
        :param apispec_meta: The API specification metadata of the view function.
        :param route: The collected route corresponding to the operation.
        """
        reqschema_meta = apispec_meta.get(const.APISPEC_REQ_SCHEMA_KEY, {})

        if reqschema_meta:
            schema_fields = self._get_schema_fields(reqschema_meta["schema"])

            operation_spec["requestBody"] = {
                "required": True,
                "description": self._prepare_description(reqschema_meta["description"]),
                "content": {
                    const.MIMETYPE_JSON: {
                        "schema": self._get_request_schema(schema_fields)
                    },
                },
            }
        elif route["method"] == "put":
            # We just assume that a binary upload is required in this case, as there is
            # currently no other way to retrieve this information.
            operation_spec["requestBody"] = {
                "required": True,
                "content": {const.MIMETYPE_BINARY: {}},
            }

    def _get_schema_fields(self, schema, is_partial=False):
        """Collect the metadata of all loadable fields of a schema.

        :param schema: The schema instance to collect the fields of.
        :param is_partial: (optional) Whether all fields of the schema should be
            treated as partially loaded, i.e. as not required.
        :return: A dictionary mapping each field name to a dictionary containing the
            keys ``"required"``, ``"many"`` and ``"type"`` as well as ``"nested"`` for
            nested schemas. Required fields come first, fields are sorted by name
            otherwise.
        """

        def _get_field_attr(field, attr, default=None):
            # Try to retrieve the attribute from the custom field metadata first.
            if attr in field.metadata:
                return field.metadata[attr]

            if hasattr(field, attr):
                return getattr(field, attr)

            return default

        partial = schema.partial
        schema_fields = {}

        for name, field in schema.fields.items():
            if field.dump_only:
                continue

            field_meta = {
                "required": _get_field_attr(field, "required", False),
                "many": _get_field_attr(field, "many", False),
                "type": _get_field_attr(
                    field, "type", SCHEMA_TYPE_MAP.get(field.__class__, "object")
                ),
            }

            # This has to be determined separately for each field, as a field that is
            # listed as partial must not affect the fields that come after it.
            field_is_partial = (
                is_partial
                or partial is True
                or (
                    isinstance(partial, (tuple, list, set, frozenset))
                    and name in partial
                )
            )

            if field_is_partial:
                field_meta["required"] = False

            # Pluck fields have to be checked first, as they are nested fields as well.
            if isinstance(field, fields.Pluck):
                field_meta["type"] = SCHEMA_TYPE_MAP.get(
                    field.schema.fields[field.field_name].__class__, "object"
                )
            elif isinstance(field, fields.Nested):
                # Pass along the information whether the schema is loaded partially, as
                # we can't retrieve it from the nested schema directly.
                field_meta["nested"] = self._get_schema_fields(
                    field.schema, is_partial=field_is_partial
                )

            schema_fields[name] = field_meta

        # Required fields first, sorted by name within each group.
        return dict(
            sorted(
                schema_fields.items(),
                key=lambda item: (not item[1]["required"], item[0]),
            )
        )

    def _get_request_schema(self, fields):
        """Convert collected schema fields to an object schema.

        :param fields: The field metadata as returned by :meth:`_get_schema_fields`.
        :return: A dictionary describing an object with one property per field. The
            ``"required"`` key is only included if at least one field is required.
        """
        properties = {}
        required_fields = []

        for name, data in fields.items():
            field_type = data["type"]

            # If the type definition has been provided as a dictionary, take it as-is.
            if isinstance(field_type, dict):
                properties[name] = field_type
            else:
                if "nested" in data:
                    field_schema = self._get_request_schema(data["nested"])
                else:
                    field_schema = {"type": field_type}

                if data["many"]:
                    # Use the complete schema for the items, so the properties of
                    # nested schemas are not lost.
                    field_schema = {"type": "array", "items": field_schema}

                properties[name] = field_schema

            if data["required"]:
                required_fields.append(name)

        schema = {
            "type": "object",
            "properties": properties,
        }

        if required_fields:
            schema["required"] = required_fields

        return schema