2
0

Compare commits

..

No commits in common. "10586541b506ec4a79d4a2bf6c9f80a88b60c2a4" and "24025686d48e9e4d391bb7a952607e4564caf1e2" have entirely different histories.

7 changed files with 32 additions and 804 deletions

View File

@ -93,6 +93,7 @@ class Env:
COLOR_YELLOW: str = '' COLOR_YELLOW: str = ''
COLOR_RED: str = '' COLOR_RED: str = ''
COLOR_CYAN: str = '' COLOR_CYAN: str = ''
fcgi_cgi: str = ''
dir: str = '' dir: str = ''
defaultwww: str = '' defaultwww: str = ''
log: io.TextIOBase = dataclasses.field(default_factory=io.StringIO) log: io.TextIOBase = dataclasses.field(default_factory=io.StringIO)
@ -182,6 +183,8 @@ class TestBase(metaclass=_TestBaseMeta):
no_docroot = False no_docroot = False
vhostdir: str = '' vhostdir: str = ''
_active: bool # false if feature_check failed (or failed on parent)
@staticmethod @staticmethod
def _vhostname(testname: str) -> str: def _vhostname(testname: str) -> str:
return '.'.join(reversed(testname[1:-1].split('/'))).lower() return '.'.join(reversed(testname[1:-1].split('/'))).lower()
@ -213,7 +216,11 @@ class TestBase(metaclass=_TestBaseMeta):
self.vhostdir = self._parent.vhostdir self.vhostdir = self._parent.vhostdir
else: else:
self.vhostdir = os.path.join(self.tests.env.dir, 'www', 'vhosts', self.vhost) self.vhostdir = os.path.join(self.tests.env.dir, 'www', 'vhosts', self.vhost)
tests._add_test(self) if self.feature_check() and (not parent or parent._active):
self._active = True
tests._add_test(self)
else:
self._active = False
# internal methods, do not override # internal methods, do not override
def _prepare(self) -> None: def _prepare(self) -> None:
@ -335,6 +342,12 @@ class TestBase(metaclass=_TestBaseMeta):
def cleanup_test(self) -> None: def cleanup_test(self) -> None:
pass pass
# return False is some dependency is missing
# TODO: should we even allow for that? better to fail,
# at least by default.
def feature_check(self) -> bool:
return True
def class2testname(name) -> str: def class2testname(name) -> str:
if name.startswith("Test"): if name.startswith("Test"):

View File

@ -1,662 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
import dataclasses
import enum
import struct
import typing
class ProtocolException(Exception):
pass
class Type(enum.IntEnum):
BEGIN_REQUEST = 1 # web server -> backend
ABORT_REQUEST = 2 # web server -> backend
END_REQUEST = 3 # backend -> web server (status)
PARAMS = 4 # web server -> backend (stream name-value pairs)
STDIN = 5 # web server -> backend (stream request body)
STDOUT = 6 # backend -> web server (stream response body)
STDERR = 7 # backend -> web server (stream error messages)
DATA = 8 # web server -> backend (stream additional data)
GET_VALUES = 9 # web server -> backend (request names-value pairs with empty values)
GET_VALUES_RESULT = 10 # backend -> web server (response name-value pairs)
UNKNOWN_TYPE = 11
SERVER_TYPES = {
Type.BEGIN_REQUEST,
Type.ABORT_REQUEST,
Type.PARAMS,
Type.STDIN,
Type.DATA,
Type.GET_VALUES,
}
CLIENT_TYPES = {
Type.END_REQUEST,
Type.STDOUT,
Type.STDERR,
Type.GET_VALUES_RESULT,
Type.UNKNOWN_TYPE,
}
_NV_PAIR_LONG_LENGTH_BIT = 2**31
class Role(enum.IntEnum):
RESPONDER = 1
AUTHORIZER = 2
FILTER = 3
class ProtocolStatus(enum.IntEnum):
REQUEST_COMPLETE = 0
CANT_MPX_CONN = 1
OVERLOADED = 2
UNKNOWN_ROLE = 3
@dataclasses.dataclass
class Packet:
data: bytes
pckt_type: int
req_id: int
def unpack_server_message(self) -> ServerMessage:
unpack = _UNPACK_SERVER_MESSAGE.get(self.pckt_type, None)
if not unpack:
raise ProtocolException(f'Message type not allowed from server {self.pckt_type}')
return unpack(self.req_id, self.data)
def _pair_encode_len(data: bytes) -> bytes:
data_len = len(data)
if data_len >= 128:
assert data_len < _NV_PAIR_LONG_LENGTH_BIT
return struct.pack('>I', data_len | _NV_PAIR_LONG_LENGTH_BIT)
else:
return struct.pack('>B', data_len)
def _pair_decode_len(payload: bytes) -> typing.Tuple[int, bytes]:
if len(payload) == 0:
raise ProtocolException('Unexpected end of data; looking for pair name/value length')
data_len = payload[0]
if data_len >= 128:
if len(payload) < 4:
raise ProtocolException('Unexpected end of data; looking for pair name/value length')
data_len, = struct.unpack('>I', payload[:4])
return (data_len & ~_NV_PAIR_LONG_LENGTH_BIT, payload[4:])
else:
return (data_len, payload[1:])
def pack_name_value_pairs(
pairs: typing.Iterable[typing.Tuple[bytes, bytes]],
) -> bytes:
payload = b''
for name, value in pairs:
payload += _pair_encode_len(name) + _pair_encode_len(value) + name + value
return payload
def unpack_name_value_pairs(
payload: bytes,
) -> typing.Generator[typing.Tuple[bytes, bytes], None, None]:
while payload:
(name_len, payload) = _pair_decode_len(payload)
(value_len, payload) = _pair_decode_len(payload)
if len(payload) < name_len:
raise ProtocolException('Unexpected end of data; looking for pair name data')
name, payload = (payload[:name_len], payload[name_len:])
if len(payload) < value_len:
raise ProtocolException('Unexpected end of data; looking for pair value data')
value, payload = (payload[:value_len], payload[value_len:])
yield name, value
@dataclasses.dataclass
class MsgBeginRequest:
req_id: int # must be > 0
role: Role = Role.RESPONDER
keep_alive: bool = False
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid BEGIN_REQUEST request id {self.req_id}')
flags = 0
if self.keep_alive:
flags |= 0x01
payload = struct.pack('>HBxxxxx', self.role.value, flags)
return Packet(data=payload, pckt_type=Type.BEGIN_REQUEST.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgBeginRequest:
if len(payload) != 8:
raise ProtocolException(f'invalid BEGIN_REQUEST payload {payload!r}')
if req_id == 0:
raise ProtocolException(f'invalid BEGIN_REQUEST request id {req_id}')
role_num, flags = struct.unpack('>HBxxxxx', payload)
try:
role = Role(role_num)
except ValueError:
raise ProtocolException(f'invalid BEGIN_REQUEST invalid role {role_num}')
keep_alive = 0 != (flags & 1)
return MsgBeginRequest(req_id=req_id, role=role, keep_alive=keep_alive)
@dataclasses.dataclass
class MsgAbortRequest:
req_id: int # must be > 0
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid ABORT_REQUEST request id {self.req_id}')
return Packet(data=b'', pckt_type=Type.ABORT_REQUEST.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgAbortRequest:
if payload:
raise ProtocolException('non-empty ABORT_REQUEST')
return MsgAbortRequest(req_id=req_id)
@dataclasses.dataclass
class MsgEndRequest:
req_id: int # must be > 0
app_status: int = 0
protocol_status: ProtocolStatus = ProtocolStatus.REQUEST_COMPLETE
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid END_REQUEST request id {self.req_id}')
payload = struct.pack('>IBxxx', self.app_status, self.protocol_status.value)
return Packet(data=payload, pckt_type=Type.END_REQUEST.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgEndRequest:
if len(payload) != 8:
raise ProtocolException(f'invalid END_REQUEST payload {payload!r}')
if req_id == 0:
raise ProtocolException(f'invalid END_REQUEST request id {req_id}')
app_status, protocol_status_num = struct.unpack('>IBxxx', payload)
try:
protocol_status = ProtocolStatus(protocol_status_num)
except ValueError:
raise ProtocolException(f'invalid END_REQUEST invalid protocol status {protocol_status_num}')
return MsgEndRequest(req_id=req_id, app_status=app_status, protocol_status=protocol_status)
@dataclasses.dataclass
class _MsgStream:
req_id: int # must be > 0
payload: bytes
class MsgParams(_MsgStream):
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid PARAMS request id {self.req_id}')
return Packet(data=self.payload, pckt_type=Type.PARAMS.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgParams:
if req_id == 0:
raise ProtocolException(f'invalid PARAMS request id {req_id}')
return MsgParams(req_id=req_id, payload=payload)
class MsgStdin(_MsgStream):
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid STDIN request id {self.req_id}')
return Packet(data=self.payload, pckt_type=Type.STDIN.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgStdin:
if req_id == 0:
raise ProtocolException(f'invalid STDIN request id {req_id}')
return MsgStdin(req_id=req_id, payload=payload)
class MsgStdout(_MsgStream):
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid STDOUT request id {self.req_id}')
return Packet(data=self.payload, pckt_type=Type.STDOUT.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgStdout:
if req_id == 0:
raise ProtocolException(f'invalid STDOUT request id {req_id}')
return MsgStdout(req_id=req_id, payload=payload)
class MsgStderr(_MsgStream):
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid STDERR request id {self.req_id}')
return Packet(data=self.payload, pckt_type=Type.STDERR.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgStderr:
if req_id == 0:
raise ProtocolException(f'invalid STDERR request id {req_id}')
return MsgStderr(req_id=req_id, payload=payload)
class MsgData(_MsgStream):
def pack(self) -> Packet:
if self.req_id == 0:
raise ProtocolException(f'invalid DATA request id {self.req_id}')
return Packet(data=self.payload, pckt_type=Type.DATA.value, req_id=self.req_id)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgData:
if req_id == 0:
raise ProtocolException(f'invalid DATA request id {req_id}')
return MsgData(req_id=req_id, payload=payload)
@dataclasses.dataclass
class MsgGetValues:
names: set[bytes]
def pack(self) -> Packet:
payload = pack_name_value_pairs((name, b'') for name in self.names)
return Packet(data=payload, pckt_type=Type.GET_VALUES.value, req_id=0)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgGetValues:
if req_id != 0:
raise ProtocolException(f'invalid GET_VALUES request id {req_id}')
names: set[bytes] = set()
for name, value in unpack_name_value_pairs(payload):
if value:
raise ProtocolException(f'non-empty value {value!r} in GET_VALUES pair for name {name!r}')
names.add(name)
return MsgGetValues(names=names)
@dataclasses.dataclass
class MsgGetValuesResult:
values: dict[bytes, bytes]
def pack(self) -> Packet:
payload = pack_name_value_pairs(self.values.items())
return Packet(data=payload, pckt_type=Type.GET_VALUES_RESULT.value, req_id=0)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgGetValuesResult:
if req_id != 0:
raise ProtocolException(f'invalid GET_VALUES_RESULT request id {req_id}')
values = dict(unpack_name_value_pairs(payload))
return MsgGetValuesResult(values=values)
@dataclasses.dataclass
class MsgUnknownType:
pckt_type: int
def pack(self) -> Packet:
payload = struct.pack('>Bxxxxxxx', self.pckt_type)
return Packet(data=payload, pckt_type=Type.UNKNOWN_TYPE.value, req_id=0)
@staticmethod
def unpack(req_id: int, payload: bytes) -> MsgUnknownType:
if len(payload) != 8:
raise ProtocolException(f'invalid UNKNOWN_TYPE payload {payload!r}')
if req_id != 0:
raise ProtocolException(f'invalid UNKNOWN_TYPE request id {req_id}')
pckt_type, = struct.unpack('>Bxxxxxxx', payload)
return MsgUnknownType(pckt_type=pckt_type)
ServerRequestMessage = typing.Union[
MsgBeginRequest,
MsgAbortRequest,
MsgParams,
MsgStdin,
MsgData,
]
ServerManagementMessage = MsgGetValues # just one for now
ServerMessage = typing.Union[ServerRequestMessage, ServerManagementMessage]
ClientRequestMessage = typing.Union[
MsgEndRequest,
MsgStdout,
MsgStderr,
]
ClientManagementMessage = typing.Union[
MsgGetValuesResult,
MsgUnknownType,
]
ClientMessage = typing.Union[ClientRequestMessage, ClientManagementMessage]
_UNPACK_SERVER_MESSAGE: dict[int, typing.Callable[[int, bytes], ServerMessage]] = {
Type.BEGIN_REQUEST.value: MsgBeginRequest.unpack,
Type.ABORT_REQUEST.value: MsgAbortRequest.unpack,
Type.PARAMS.value: MsgParams.unpack,
Type.STDIN.value: MsgStdin.unpack,
Type.DATA.value: MsgData.unpack,
Type.GET_VALUES.value: MsgGetValues.unpack,
}
_UNPACK_CLIENT_MESSAGE: dict[int, typing.Callable[[int, bytes], ClientMessage]] = {
Type.END_REQUEST.value: MsgEndRequest.unpack,
Type.STDOUT.value: MsgStdout.unpack,
Type.STDERR.value: MsgStderr.unpack,
Type.GET_VALUES_RESULT.value: MsgGetValuesResult.unpack,
Type.UNKNOWN_TYPE.value: MsgUnknownType.unpack,
}
class LowLevelConnection:
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
self._reader = reader
self._writer = writer
self._closed = False
self._write_lock = asyncio.Lock()
def close(self) -> None:
if not self._closed:
self._closed = True
self._writer.close()
def _abort(self, msg: str) -> typing.NoReturn:
self.close()
raise ProtocolException(msg)
def _error_close(self) -> typing.NoReturn:
self.close()
raise EOFError("FastCGI connection is closed")
async def read(self) -> Packet:
if self._closed:
self._error_close()
try:
hdr = await self._reader.readexactly(8)
except asyncio.IncompleteReadError as e:
if e.partial:
self._abort("Incomplete header")
self._error_close()
except (EOFError, ConnectionError):
self._error_close()
version, pckt_type, req_id, data_len, pad_len = struct.unpack('>BBHHBx', hdr)
if version != 1:
self._abort(f"Invalid packet version {version}")
if pad_len >= 8 or (data_len + pad_len) % 8 != 0:
self._abort(f"Invalid packet pad_len (data_len = {data_len}, pad_len = {pad_len})")
total = data_len + pad_len
try:
data = await self._reader.readexactly(total)
except EOFError:
self._abort("Incomplete data")
return Packet(data=data[:data_len], pckt_type=pckt_type, req_id=req_id)
async def write(self, packet: Packet):
async with self._write_lock:
if self._closed:
self._error_close()
version = 1
data_len = len(packet.data)
pad_len = (8 - (data_len % 8)) % 8
padding = b' '[:pad_len]
hdr = struct.pack('>BBHHBx', version, packet.pckt_type, packet.req_id, data_len, pad_len)
self._writer.writelines((hdr, packet.data, padding))
try:
await self._writer.drain()
except ConnectionError as e:
self._abort(f"Connection reset: {e}")
ApplicationRequestT = typing.TypeVar('ApplicationRequestT', bound='ApplicationRequest')
class Application(typing.Generic[ApplicationRequestT]):
def __init__(
self,
*,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
request_cls: type[ApplicationRequestT],
) -> None:
self._lower = LowLevelConnection(reader=reader, writer=writer)
self._request_cls = request_cls
self._requests: dict[int, ApplicationRequestT] = {}
self._msg_handlers = {
Type.BEGIN_REQUEST.value: self._handle_begin_request,
Type.ABORT_REQUEST.value: self._handle_abort_request,
Type.PARAMS.value: self._handle_params,
Type.STDIN.value: self._handle_stdin,
Type.DATA.value: self._handle_data,
Type.GET_VALUES.value: self._handle_get_values,
}
self._values: dict[bytes, bytes] = {}
def _close(self) -> None:
self._lower.close()
requests = list(self._requests.values())
self._requests.clear()
for req in requests:
req._app = None
for req in requests:
req._recv_task.cancel()
async def _handle_begin_request(self, packet: Packet) -> None:
msg = MsgBeginRequest.unpack(packet.req_id, packet.data)
if msg.req_id in self._requests:
raise ProtocolException(f'Request id {msg.req_id} still in use')
self._requests[msg.req_id] = req = self._request_cls(
req_id=msg.req_id,
keep_alive=msg.keep_alive,
)
req._app = self
async def _handle_abort_request(self, packet: Packet) -> None:
msg = MsgAbortRequest.unpack(packet.req_id, packet.data)
req = self._requests.get(msg.req_id, None)
if req:
# protocol spec: ignore messages for unknown request ids
await req._recv_queue.put(msg)
async def _handle_params(self, packet: Packet) -> None:
msg = MsgParams.unpack(packet.req_id, packet.data)
req = self._requests.get(msg.req_id, None)
if req:
# protocol spec: ignore messages for unknown request ids
await req._recv_queue.put(msg)
async def _handle_stdin(self, packet: Packet) -> None:
msg = MsgStdin.unpack(packet.req_id, packet.data)
req = self._requests.get(msg.req_id, None)
if req:
# protocol spec: ignore messages for unknown request ids
await req._recv_queue.put(msg)
async def _handle_data(self, packet: Packet) -> None:
msg = MsgData.unpack(packet.req_id, packet.data)
req = self._requests.get(msg.req_id, None)
if req:
# protocol spec: ignore messages for unknown request ids
await req._recv_queue.put(msg)
async def _handle_get_values(self, packet: Packet) -> None:
msg = MsgGetValues.unpack(packet.req_id, packet.data)
values: dict[bytes, bytes] = {}
for name in msg.names:
value = self._values.get(name, None)
if not value is None:
values[name] = value
await self._write(MsgGetValuesResult(values=values))
async def _write(self, msg: ClientMessage) -> None:
await self._lower.write(msg.pack())
async def run(self) -> None:
try:
while True:
try:
packet = await self._lower.read()
except (ProtocolException, EOFError):
return
try:
Type(packet.pckt_type)
except ValueError:
await self._write(MsgUnknownType(pckt_type=packet.pckt_type))
continue
handler = self._msg_handlers.get(packet.pckt_type, None)
if not handler:
return
await handler(packet)
finally:
self._close()
class ApplicationRequest:
def __init__(self, *, req_id: int, keep_alive: bool) -> None:
self._app: typing.Optional[Application] = None
self.req_id = req_id
self.keep_alive = keep_alive
self.params: dict[bytes, bytes] = {}
self._params_buf = b''
self.params_complete = False
self.stdin_closed = False
self.data_closed = False
self.stdout_closed = False
self.stderr_closed = False
self.aborted = False
self.completed = False
self._recv_queue: asyncio.Queue[typing.Union[
MsgAbortRequest,
MsgParams,
MsgStdin,
MsgData,
]] = asyncio.Queue(maxsize=16)
self._recv_task = asyncio.Task(self._start())
async def _start(self) -> None:
try:
await self.run()
finally:
if not self.completed:
await self._handle_abort()
async def run(self) -> None:
while True:
item = await self._recv_queue.get()
if isinstance(item, MsgAbortRequest):
await self._handle_abort()
elif isinstance(item, MsgParams):
await self._recv_params_stream(item.payload)
elif isinstance(item, MsgStdin):
await self._recv_stdin_stream(item.payload)
# elif isinstance(item, MsgData):
else:
await self._recv_data_stream(item.payload)
async def exit(
self,
*,
status: int = 0,
protocol_status: ProtocolStatus = ProtocolStatus.REQUEST_COMPLETE,
) -> None:
self.stdout_closed = self.stderr_closed = True
if self._app:
app = self._app
self._app = None
del app._requests[self.req_id]
await app._write(
MsgEndRequest(req_id=self.req_id, app_status=status, protocol_status=protocol_status),
)
self._recv_task.cancel()
while True:
try:
self._recv_queue.get_nowait()
except asyncio.QueueEmpty:
break
async def finish(
self,
*,
status: int = 0,
protocol_status: ProtocolStatus = ProtocolStatus.REQUEST_COMPLETE,
) -> None:
await self.write_stdout(b'')
await self.write_stderr(b'')
self.completed = True
await self.exit(status=status, protocol_status=protocol_status)
async def write_stdout(self, data: bytes, *, no_close: bool = False):
if self.stdout_closed:
if len(data) != 0:
raise Exception('stdout already closed')
if len(data) == 0:
if not no_close:
# sending empty data signals EOF by default
self.stdout_closed = True
else:
return # don't send the EOF message!
if self._app:
await self._app._write(MsgStdout(req_id=self.req_id, payload=data))
async def write_stderr(self, data: bytes, *, no_close: bool = False):
if self.stderr_closed:
if len(data) != 0:
raise Exception('stderr already closed')
if len(data) == 0:
if not no_close:
# sending empty data signals EOF by default
self.stderr_closed = True
else:
return # don't send the EOF message!
if self._app:
await self._app._write(MsgStderr(req_id=self.req_id, payload=data))
async def _handle_abort(self) -> None:
if self.completed:
return
if not self.aborted:
self.aborted = True
await self.handle_abort()
async def handle_abort(self) -> None:
await self.exit(status=-1)
async def _recv_params_stream(self, data: bytes) -> None:
if self.params_complete:
raise ProtocolException('Params already closed')
if len(data) == 0:
self.params_complete = True
self.params.update(unpack_name_value_pairs(self._params_buf))
self._params_buf = b''
await self.received_params()
else:
self._params_buf += data
async def received_params(self) -> None:
pass
async def _recv_stdin_stream(self, data: bytes) -> None:
if not self.params_complete:
raise ProtocolException('Received stdin data before params finished')
if len(data) == 0:
self._stdin_closed = True
await self.recv_stdin(data)
async def recv_stdin(self, data: bytes) -> None:
# ignored by default
pass
async def _recv_data_stream(self, data: bytes) -> None:
if len(data) == 0:
self._data_closed = True
await self.recv_data(data)
async def recv_data(self, data: bytes) -> None:
# ignored by default
pass

View File

@ -173,6 +173,8 @@ def setup_env() -> Env:
env.COLOR_RED = env.color and "\033[1;31m" or "" env.COLOR_RED = env.color and "\033[1;31m" or ""
env.COLOR_CYAN = env.color and "\033[1;36m" or "" env.COLOR_CYAN = env.color and "\033[1;36m" or ""
env.fcgi_cgi = which('fcgi-cgi') or ''
env.dir = mkdtemp(suffix='-l2-tests') env.dir = mkdtemp(suffix='-l2-tests')
env.defaultwww = os.path.join(env.dir, "www", "default") env.defaultwww = os.path.join(env.dir, "www", "default")

View File

@ -1,13 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from io import BytesIO from io import BytesIO
import hashlib
import os
import time import time
import hashlib
import pycurl import pycurl
from pylt.base import ModuleTest, Tests from pylt.base import ModuleTest
from pylt.requests import CurlRequest from pylt.requests import CurlRequest
from pylt.service import FastCGI from pylt.service import FastCGI
@ -24,9 +22,9 @@ def generate_body(seed: str, size: int) -> bytes:
class CGI(FastCGI): class CGI(FastCGI):
name = "fcgi_cgi" name = "fcgi_cgi"
def __init__(self, *, tests: Tests) -> None: def prepare_service(self) -> None:
super().__init__(tests=tests) self.binary = [self.tests.env.fcgi_cgi]
self.binary = [os.path.join(self.tests.env.sourcedir, "tests", "run-fcgi-cgi.py")] super().prepare_service()
SCRIPT_ENVCHECK = r"""#!/bin/sh SCRIPT_ENVCHECK = r"""#!/bin/sh
@ -190,8 +188,9 @@ if phys.exists and phys.path =$ ".cgi" {
""" """
def __init__(self, *, tests: Tests) -> None: def feature_check(self) -> bool:
super().__init__(tests=tests) if self.tests.env.fcgi_cgi is None:
return self.MissingFeature('fcgi-cgi')
cgi = CGI(tests=self.tests) cgi = CGI(tests=self.tests)
self.plain_config = f""" self.plain_config = f"""
setup {{ module_load "mod_fastcgi"; }} setup {{ module_load "mod_fastcgi"; }}
@ -200,7 +199,9 @@ cgi = {{
fastcgi "unix:{cgi.sockfile}"; fastcgi "unix:{cgi.sockfile}";
}}; }};
""" """
self.tests.add_service(cgi) self.tests.add_service(cgi)
return True
def prepare_test(self) -> None: def prepare_test(self) -> None:
self.prepare_vhost_file("envcheck.cgi", SCRIPT_ENVCHECK, mode=0o755) self.prepare_vhost_file("envcheck.cgi", SCRIPT_ENVCHECK, mode=0o755)

View File

@ -60,9 +60,7 @@ class Test(base.ModuleTest):
memcache; memcache;
""" """
def __init__(self, *, tests: base.Tests) -> None: def feature_check(self) -> bool:
super().__init__(tests=tests)
memcached = Memcached(tests=self.tests) memcached = Memcached(tests=self.tests)
self.plain_config = f""" self.plain_config = f"""
setup {{ module_load "mod_memcached"; }} setup {{ module_load "mod_memcached"; }}
@ -77,4 +75,6 @@ memcache = {{
}}); }});
}}; }};
""" """
self.tests.add_service(memcached) self.tests.add_service(memcached)
return True

View File

@ -51,9 +51,7 @@ class Test(base.ModuleTest):
run_scgi; run_scgi;
""" """
def __init__(self, *, tests: base.Tests) -> None: def feature_check(self) -> bool:
super().__init__(tests=tests)
scgi = SCGI(tests=self.tests) scgi = SCGI(tests=self.tests)
self.plain_config = f""" self.plain_config = f"""
setup {{ module_load "mod_scgi"; }} setup {{ module_load "mod_scgi"; }}
@ -62,4 +60,6 @@ run_scgi = {{
core.wsgi ( "/scgi", {{ scgi "unix:{scgi.sockfile}"; }} ); core.wsgi ( "/scgi", {{ scgi "unix:{scgi.sockfile}"; }} );
}}; }};
""" """
self.tests.add_service(scgi) self.tests.add_service(scgi)
return True

View File

@ -1,126 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import os
import socket
import sys
import typing
sys.path.append(os.path.dirname(__file__))
import pylt.fastcgi
class TestAppRequest(pylt.fastcgi.ApplicationRequest):
def __init__(self, *, req_id: int, keep_alive: bool) -> None:
super().__init__(req_id=req_id, keep_alive=keep_alive)
self._process: typing.Optional[asyncio.subprocess.Process] = None
self._task_stdout: typing.Optional[asyncio.Task] = None
self._task_stderr: typing.Optional[asyncio.Task] = None
self._task_wait: typing.Optional[asyncio.Task] = None
async def received_params(self) -> None:
target = self.params.get(b'INTERPRETER') or self.params.get(b'SCRIPT_FILENAME')
if not target:
await self.write_stdout(b"Status: 500\r\nContent-Type: text/plain\r\n\r\n")
await self.finish()
return
env = dict(self.params)
env[b'PATH'] = os.getenvb(b'PATH', b'')
self._process = process = await asyncio.create_subprocess_exec(
target,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
self._task_stdout = asyncio.Task(self._handle_task_stdout(process))
self._task_stderr = asyncio.Task(self._handle_task_stderr(process))
self._task_wait = asyncio.Task(self._handle_task_wait(process))
async def recv_stdin(self, data: bytes) -> None:
if not self._process or not self._process.stdin:
return
try:
if data:
self._process.stdin.write(data)
await self._process.stdin.drain()
else:
self._process.stdin.close()
await self._process.stdin.wait_closed()
except IOError:
stdin = self._process.stdin
self._process.stdin = None
try:
stdin.close()
except IOError:
pass
async def _handle_task_stdout(self, process: asyncio.subprocess.Process):
stdout = process.stdout
try:
if not stdout:
return
while True:
try:
chunk = await stdout.read(8192)
if not chunk:
return
await self.write_stdout(chunk)
except EOFError:
return
finally:
await self.write_stdout(b'')
async def _handle_task_stderr(self, process: asyncio.subprocess.Process):
stderr = process.stderr
try:
if not stderr:
return
while True:
try:
chunk = await stderr.read(8192)
if not chunk:
return
await self.write_stderr(chunk)
except EOFError:
return
finally:
await self.write_stderr(b'')
async def _handle_task_wait(self, process: asyncio.subprocess.Process):
status = await process.wait()
assert self._task_stdout
await self._task_stdout
assert self._task_stderr
await self._task_stderr
self._process = None
await self.finish(status=status)
async def handle_fcgi_cgi(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
print("fcgi-cgi: Incoming connection", flush=True)
app = pylt.fastcgi.Application(reader=reader, writer=writer, request_cls=TestAppRequest)
await app.run()
async def main() -> None:
sock = socket.socket(fileno=0)
if sock.type == socket.AF_UNIX:
server = await asyncio.start_unix_server(handle_fcgi_cgi, sock=sock, start_serving=False)
else:
server = await asyncio.start_server(handle_fcgi_cgi, sock=sock, start_serving=False)
addr = server.sockets[0].getsockname()
print(f'Serving on {addr}', flush=True)
async with server:
await server.serve_forever()
try:
asyncio.run(main())
except KeyboardInterrupt:
pass