Compare commits
2 Commits
424e1a37f8
...
c8bf635551
Author | SHA1 | Date | |
---|---|---|---|
c8bf635551 | |||
98ea1dc7de |
@ -42,9 +42,32 @@ my %text_utf8 = map { $_ => 1 } qw(
|
|||||||
|
|
||||||
# map extension to hash which maps types to the type they should be replaced with
|
# map extension to hash which maps types to the type they should be replaced with
|
||||||
my %manual_conflicts_resolve = (
|
my %manual_conflicts_resolve = (
|
||||||
|
'.asn' => {
|
||||||
|
'chemical/x-ncbi-asn1-spec' => 'application/octet-stream',
|
||||||
|
'chemical/x-ncbi-asn1' => 'application/octet-stream',
|
||||||
|
},
|
||||||
|
'.otf' => {
|
||||||
|
'application/font-sfnt' => 'font/otf',
|
||||||
|
'font/sfnt' => 'font/otf',
|
||||||
|
'font/ttf' => 'font/otf',
|
||||||
|
},
|
||||||
|
'.pcx' => {
|
||||||
|
'image/vnd.zbrush.pcx' => 'image/pcx',
|
||||||
|
},
|
||||||
|
'.png' => {
|
||||||
|
'image/vnd.mozilla.apng' => 'image/png',
|
||||||
|
},
|
||||||
'.ra' => {
|
'.ra' => {
|
||||||
'audio/x-pn-realaudio' => 'audio/x-realaudio',
|
'audio/x-pn-realaudio' => 'audio/x-realaudio',
|
||||||
},
|
},
|
||||||
|
'.ttf' => {
|
||||||
|
'application/font-sfnt' => 'font/ttf',
|
||||||
|
'font/sfnt' => 'font/ttf',
|
||||||
|
'font/otf' => 'font/ttf',
|
||||||
|
},
|
||||||
|
'.woff' => {
|
||||||
|
'application/font-woff' => 'font/woff',
|
||||||
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
open MIMETYPES, "/etc/mime.types" or die "Can't open mime.types: $!";
|
open MIMETYPES, "/etc/mime.types" or die "Can't open mime.types: $!";
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,62 +1,65 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import asyncore
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
class MemcacheEntry:
|
class MemcacheEntry:
|
||||||
def __init__(self, flags, exptime, data, cas):
|
def __init__(self, flags: bytes, exptime: bytes, data: bytes, cas: bytes):
|
||||||
self.flags = flags
|
self.flags: int = self._flags(flags)
|
||||||
self.data = data
|
self.data: bytes = data
|
||||||
self.cas = cas
|
self.cas: bytes = cas
|
||||||
self.setExptime(exptime)
|
self.expire: typing.Optional[float] = self._ttl(exptime)
|
||||||
|
|
||||||
def _ttl(self, str):
|
@staticmethod
|
||||||
v = int(str)
|
def _flags(flags: bytes) -> int:
|
||||||
if v < 0: raise ValueError('exptime not an unsigned integer')
|
v = int(flags)
|
||||||
if v > 0 and v < 365*24*3600: v += time.time()
|
if v < 0 or v >= 2**32: raise ValueError('flags not an unsigned 32-bit integer')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def setExptime(self, exptime):
|
@staticmethod
|
||||||
exptime = self._ttl(exptime)
|
def _ttl(value: bytes) -> typing.Optional[float]:
|
||||||
if exptime > 0:
|
v = int(value)
|
||||||
self.expire = exptime
|
if v < 0: raise ValueError('exptime not an unsigned integer')
|
||||||
|
if v > 0 and v < 365*24*3600:
|
||||||
|
return v + time.time()
|
||||||
else:
|
else:
|
||||||
self.expire = None
|
return None
|
||||||
|
|
||||||
def flush(self, exptime):
|
def setExptime(self, exptime: bytes):
|
||||||
exptime = self._ttl(exptime)
|
self.expire = self._ttl(exptime)
|
||||||
if self.expire == None or self.expire > expire:
|
|
||||||
self.expire = expire
|
def flush(self, exptime: float):
|
||||||
|
# make sure entry expires at `exptime` (or before)
|
||||||
|
if self.expire is None or self.expire > exptime:
|
||||||
|
self.expire = exptime
|
||||||
|
|
||||||
def expired(self):
|
def expired(self):
|
||||||
if self.expire != None: return self.expire < time.time()
|
return self.expire != None and self.expire < time.time()
|
||||||
return False
|
|
||||||
|
|
||||||
class MemcacheDB:
|
class MemcacheDB:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.d = dict()
|
self.d = dict()
|
||||||
self.cas = random.randint(0, 2**64-1)
|
self._cas = random.randint(0, 2**64-1)
|
||||||
|
|
||||||
def _uint64value(self, str):
|
@staticmethod
|
||||||
|
def _uint64value(str):
|
||||||
v = int(str)
|
v = int(str)
|
||||||
if v < 0 or v >= 2**64: raise ValueError('not an unsigned 64-bit integer')
|
if v < 0 or v >= 2**64: raise ValueError('not an unsigned 64-bit integer')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def _flags(self, str):
|
def _next_cas(self) -> bytes:
|
||||||
v = int(str)
|
cas = self._cas
|
||||||
if v < 0 or v >= 2**32: raise ValueError('flags not an unsigned 32-bit integer')
|
self._cas = (cas + 1) % 2**64
|
||||||
return v
|
return b'%d' % cas
|
||||||
|
|
||||||
def _next_cas(self):
|
def get(self, key: bytes):
|
||||||
cas = self.cas
|
|
||||||
self.cas = (self.cas + 1) % 2**64
|
|
||||||
return cas
|
|
||||||
|
|
||||||
def get(self, key):
|
|
||||||
if not key in self.d: return None
|
if not key in self.d: return None
|
||||||
entry = self.d[key]
|
entry = self.d[key]
|
||||||
if entry.expired():
|
if entry.expired():
|
||||||
@ -64,152 +67,156 @@ class MemcacheDB:
|
|||||||
return None
|
return None
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
def set(self, key, flags, exptime, data):
|
def set(self, key: bytes, flags: bytes, exptime: bytes, data: bytes):
|
||||||
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
||||||
return "STORED"
|
return b"STORED"
|
||||||
|
|
||||||
def add(self, key, flags, exptime, data):
|
def add(self, key: bytes, flags: bytes, exptime: bytes, data: bytes):
|
||||||
if None != self.get(key): return "NOT_STORED"
|
if None != self.get(key): return b"NOT_STORED"
|
||||||
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
||||||
return "STORED"
|
return b"STORED"
|
||||||
|
|
||||||
def replace(self, key, flags, exptime, data):
|
def replace(self, key: bytes, flags: bytes, exptime: bytes, data: bytes):
|
||||||
if None == self.get(key): return "NOT_STORED"
|
if None == self.get(key): return b"NOT_STORED"
|
||||||
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
||||||
return "STORED"
|
return b"STORED"
|
||||||
|
|
||||||
def append(self, key, data):
|
def append(self, key: bytes, data: bytes):
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
entry.data += data
|
entry.data += data
|
||||||
entry.cas = _next_cas()
|
entry.cas = self._next_cas()
|
||||||
return "STORED"
|
return b"STORED"
|
||||||
|
|
||||||
def prepend(self, key, data):
|
def prepend(self, key: bytes, data: bytes):
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
entry.data = data + entry.data
|
entry.data = data + entry.data
|
||||||
entry.cas = _next_cas()
|
entry.cas = self._next_cas()
|
||||||
return "STORED"
|
return b"STORED"
|
||||||
|
|
||||||
def cas(self, key, flags, exptime, cas, data):
|
def cas(self, key: bytes, flags: bytes, exptime: bytes, cas: bytes, data: bytes):
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
if entry.cas != cas: return "EXISTS"
|
if entry.cas != cas: return b"EXISTS"
|
||||||
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
self.d[key] = MemcacheEntry(flags, exptime, data, self._next_cas())
|
||||||
return "STORED"
|
return b"STORED"
|
||||||
|
|
||||||
def delete(self, key):
|
def delete(self, key: bytes):
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
self.d.pop(key)
|
self.d.pop(key)
|
||||||
return "DELETED"
|
return b"DELETED"
|
||||||
|
|
||||||
def incr(self, key, value):
|
def incr(self, key: bytes, value: bytes):
|
||||||
try:
|
try:
|
||||||
value = _uint64value(value)
|
value = self._uint64value(value)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return "CLIENT_ERROR " + str(e)
|
return b"CLIENT_ERROR %s" % str(e).encode('utf-8')
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
try:
|
try:
|
||||||
v = _uint64value(entry.data)
|
v = self._uint64value(entry.data)
|
||||||
v = (v + value) % 2**64
|
v = (v + value) % 2**64
|
||||||
entry.data = str(v)
|
entry.data = str(v)
|
||||||
entry.cas = _next_cas()
|
entry.cas = self._next_cas()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return "SERVER_ERROR " + str(e)
|
return b"SERVER_ERROR %s" % str(e).encode('utf-8')
|
||||||
return entry.data
|
return entry.data
|
||||||
|
|
||||||
def decr(self, key, value):
|
def decr(self, key: bytes, value: bytes):
|
||||||
try:
|
try:
|
||||||
value = _uint64value(value)
|
value = self._uint64value(value)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return "CLIENT_ERROR " + str(e)
|
return b"CLIENT_ERROR %s" % str(e).encode('utf-8')
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
try:
|
try:
|
||||||
v = _uint64value(entry.data)
|
v = self._uint64value(entry.data)
|
||||||
v = v - value
|
v = v - value
|
||||||
if v < 0: v = 0
|
if v < 0: v = 0
|
||||||
entry.data = str(v)
|
entry.data = str(v)
|
||||||
entry.cas = _next_cas()
|
entry.cas = self._next_cas()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return "SERVER_ERROR " + str(e)
|
return b"SERVER_ERROR %s" % str(e).encode('utf-8')
|
||||||
return entry.data
|
return entry.data
|
||||||
|
|
||||||
def touch(self, key, exptime):
|
def touch(self, key: bytes, exptime: bytes):
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if None == entry: return "NOT_FOUND"
|
if None == entry: return b"NOT_FOUND"
|
||||||
entry.setExptime(exptime)
|
entry.setExptime(exptime)
|
||||||
return "TOUCHED"
|
return b"TOUCHED"
|
||||||
|
|
||||||
def stats(self):
|
def stats(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def flush_all(self, exptime = None):
|
def flush_all(self, exptime: typing.Optional[bytes] = None):
|
||||||
if exptime == None:
|
if exptime is None:
|
||||||
self.d = dict()
|
self.d = dict()
|
||||||
else:
|
else:
|
||||||
for key in self.d.keys:
|
expire_at = MemcacheEntry._ttl(exptime) or time.time()
|
||||||
|
for key in self.d.keys():
|
||||||
entry = self.get(key)
|
entry = self.get(key)
|
||||||
if entry != None: entry.flush(exptime)
|
if entry != None: entry.flush(expire_at)
|
||||||
return "OK"
|
return b"OK"
|
||||||
|
|
||||||
def version(self):
|
def version(self):
|
||||||
return "VERSION python memcached stub 0.1"
|
return b"VERSION python memcached stub 0.1"
|
||||||
|
|
||||||
def verbosity(self, level):
|
def verbosity(self, level: bytes):
|
||||||
return "OK"
|
return b"OK"
|
||||||
|
|
||||||
class MemcachedHandler(asyncore.dispatcher_with_send):
|
|
||||||
def __init__(self, sock, db):
|
class MemcachedHandler:
|
||||||
asyncore.dispatcher_with_send.__init__(self, sock)
|
def __init__(self, *, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, db: MemcacheDB):
|
||||||
|
self.reader = reader
|
||||||
|
self.writer = writer
|
||||||
self.db = db
|
self.db = db
|
||||||
self.data = ''
|
self.data = b''
|
||||||
self.want_binary = None
|
self.want_binary = None
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
def _server_error(self, msg):
|
def _server_error(self, msg: str):
|
||||||
self.send('SERVER_ERROR ' + msg + '\r\n')
|
self.writer.write(b'SERVER_ERROR %s\r\n' % msg.encode('utf-8'))
|
||||||
self.data = ''
|
self.data = b''
|
||||||
self.close()
|
self.closed = True
|
||||||
|
|
||||||
def _client_error(self, msg):
|
def _client_error(self, msg: str):
|
||||||
self.send('CLIENT_ERROR ' + msg + '\r\n')
|
self.writer.write(b'CLIENT_ERROR %s\r\n' % msg.encode('utf-8'))
|
||||||
self.data = ''
|
self.data = b''
|
||||||
self.close()
|
self.closed = True
|
||||||
|
|
||||||
def _error(self, msg):
|
def _error(self):
|
||||||
self.send('ERROR\r\n')
|
self.writer.write(b'ERROR\r\n')
|
||||||
self.data = ''
|
self.data = b''
|
||||||
self.close()
|
self.closed = True
|
||||||
|
|
||||||
def _handle_binary(self, b):
|
def _handle_binary(self, b):
|
||||||
args = self.args + [b]
|
args = self.args
|
||||||
cmd = self.cmd
|
cmd = self.cmd
|
||||||
noreply = self.noreply
|
noreply = self.noreply
|
||||||
self.cmd = self.args = self.noreply = None
|
self.cmd = self.args = self.noreply = None
|
||||||
r = getattr(self.db, cmd).__call__(*args)
|
r = (getattr(self.db, cmd))(*args, b)
|
||||||
if not noreply: self.send('%s\r\n' % r)
|
if not noreply: self.writer.write(b'%s\r\n' % r)
|
||||||
|
|
||||||
def _handle_line(self, line):
|
def _handle_line(self, line):
|
||||||
args = line.split()
|
args = line.split()
|
||||||
if len(args) == 0: return _client_error("empty command")
|
if len(args) == 0: return self._client_error("empty command")
|
||||||
cmd = args[0]
|
cmd = args[0].decode('ascii')
|
||||||
args = args[1:]
|
args = args[1:]
|
||||||
noreply = False
|
noreply = False
|
||||||
if args[-1] == "noreply":
|
if args[-1] == b"noreply":
|
||||||
args.pop()
|
args.pop()
|
||||||
noreply = True
|
noreply = True
|
||||||
if cmd in ['set', 'add', 'replace', 'append', 'prepend']:
|
if cmd in ['set', 'add', 'replace', 'append', 'prepend']:
|
||||||
if len(args) != 4: return _client_error("wrong number %i of arguments for command" % 4)
|
if len(args) != 4: return self._client_error("wrong number %i of arguments for command" % 4)
|
||||||
self.want_binary = int(args[3])
|
self.want_binary = int(args[3])
|
||||||
if self.want_binary < 0: return _client_error("negative bytes length")
|
if self.want_binary < 0: return self._client_error("negative bytes length")
|
||||||
self.args = args[:3]
|
self.args = args[:3]
|
||||||
self.cmd = cmd
|
self.cmd = cmd
|
||||||
self.noreply = noreply
|
self.noreply = noreply
|
||||||
elif cmd == 'cas':
|
elif cmd == 'cas':
|
||||||
if len(args) != 5: return _client_error("wrong number %i of arguments for command" % 5)
|
if len(args) != 5: return self._client_error("wrong number %i of arguments for command" % 5)
|
||||||
self.want_binary = args[3]
|
self.want_binary = args[3]
|
||||||
args = args[:3] + args[4:]
|
args = args[:3] + args[4:]
|
||||||
self.cmd = cmd
|
self.cmd = cmd
|
||||||
@ -218,21 +225,21 @@ class MemcachedHandler(asyncore.dispatcher_with_send):
|
|||||||
for key in args:
|
for key in args:
|
||||||
entry = self.db.get(key)
|
entry = self.db.get(key)
|
||||||
if entry != None:
|
if entry != None:
|
||||||
self.send('VALUE %s %s %s\r\n%s\r\n' % (key, entry.flags, len(entry.data), entry.data))
|
self.writer.write(b'VALUE %s %d %d\r\n%s\r\n' % (key, entry.flags, len(entry.data), entry.data))
|
||||||
self.send('END\r\n')
|
self.writer.write(b'END\r\n')
|
||||||
elif cmd == 'gets':
|
elif cmd == 'gets':
|
||||||
for key in args:
|
for key in args:
|
||||||
entry = self.db.get(key)
|
entry = self.db.get(key)
|
||||||
if entry != None:
|
if entry != None:
|
||||||
self.send('VALUE %s %s %s %s\r\n%s\r\n' % (key, entry.flags, len(entry.data), entry.cas, entry.data))
|
self.writer.write(b'VALUE %s %d %d %d\r\n%s\r\n' % (key, entry.flags, len(entry.data), entry.cas, entry.data))
|
||||||
self.send('END\r\n')
|
self.writer.write(b'END\r\n')
|
||||||
elif cmd == 'stats':
|
elif cmd == 'stats':
|
||||||
for (name, value) in self.db.stats():
|
for (name, value) in self.db.stats():
|
||||||
self.send('STAT ' + name + ' ' + value + '\r\n')
|
self.writer.write(b'STAT %s %s\r\n' % (name, value))
|
||||||
self.send('END\r\n')
|
self.writer.write(b'END\r\n')
|
||||||
elif cmd in ['delete', 'incr', 'decr', 'touch', 'flush_all', 'version', 'verbosity']:
|
elif cmd in ['delete', 'incr', 'decr', 'touch', 'flush_all', 'version', 'verbosity']:
|
||||||
r = getattr(self.db, cmd).__call__(*args)
|
r = (getattr(self.db, cmd))(*args)
|
||||||
if not noreply: self.send(r + '\r\n')
|
if not noreply: self.writer.write(r + b'\r\n')
|
||||||
else:
|
else:
|
||||||
return self._error()
|
return self._error()
|
||||||
|
|
||||||
@ -241,49 +248,64 @@ class MemcachedHandler(asyncore.dispatcher_with_send):
|
|||||||
if self.want_binary != None:
|
if self.want_binary != None:
|
||||||
if len(self.data) >= self.want_binary + 2:
|
if len(self.data) >= self.want_binary + 2:
|
||||||
b = self.data[:self.want_binary]
|
b = self.data[:self.want_binary]
|
||||||
if self.data[self.want_binary:self.want_binary+2] != '\r\n':
|
if self.data[self.want_binary:self.want_binary+2] != b'\r\n':
|
||||||
return self._parse_error("wrong termination of binary data")
|
return self._client_error("wrong termination of binary data")
|
||||||
self.data = self.data[self.want_binary+2:]
|
self.data = self.data[self.want_binary+2:]
|
||||||
self._handle_binary(b)
|
self._handle_binary(b)
|
||||||
else:
|
else:
|
||||||
return # wait for more data
|
return # wait for more data
|
||||||
else:
|
else:
|
||||||
pos = self.data.find('\r\n')
|
pos = self.data.find(b'\r\n')
|
||||||
if pos < 0:
|
if pos < 0:
|
||||||
if len(self.data) > 512:
|
if len(self.data) > 512:
|
||||||
return self._parse_error("command too long")
|
return self._client_error("command too long")
|
||||||
return # wait for more data
|
return # wait for more data
|
||||||
l = self.data[:pos]
|
l = self.data[:pos]
|
||||||
self.data = self.data[pos+2:]
|
self.data = self.data[pos+2:]
|
||||||
self._handle_line(l)
|
self._handle_line(l)
|
||||||
|
|
||||||
def handle_read(self):
|
async def handle(self):
|
||||||
self.data += self.recv(8192)
|
while not self.closed:
|
||||||
|
await self.writer.drain()
|
||||||
|
next_buf = await self.reader.read(8192)
|
||||||
|
if len(next_buf) == 0:
|
||||||
|
# received EOF, close immediately.
|
||||||
|
self.writer.close()
|
||||||
|
return
|
||||||
|
self.data += next_buf
|
||||||
try:
|
try:
|
||||||
self._handle_data()
|
self._handle_data()
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
self._client_error("wrong number of arguments for command: %s" % e)
|
self._client_error("wrong number of arguments for command: %s" % e)
|
||||||
print traceback.format_exc()
|
print(traceback.format_exc())
|
||||||
|
# close
|
||||||
|
await self.writer.drain()
|
||||||
|
self.writer.close()
|
||||||
|
await self.writer.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
class MemcachedServer(asyncore.dispatcher):
|
async def main():
|
||||||
|
sock = socket.socket(fileno=0)
|
||||||
|
if sock.type == socket.AF_UNIX:
|
||||||
|
start_server = asyncio.start_unix_server
|
||||||
|
else:
|
||||||
|
start_server = asyncio.start_server
|
||||||
|
db = MemcacheDB()
|
||||||
|
|
||||||
def __init__(self, sock):
|
async def handle_memcache_client(reader, writer):
|
||||||
asyncore.dispatcher.__init__(self)
|
print(f"Memcached: Incoming connection", flush=True)
|
||||||
sock.setblocking(0)
|
await MemcachedHandler(reader=reader, writer=writer, db=db).handle()
|
||||||
self.set_socket(sock)
|
|
||||||
self.accepting = True
|
server = await start_server(handle_memcache_client, sock=sock, start_serving=False)
|
||||||
self.db = MemcacheDB()
|
|
||||||
|
addr = server.sockets[0].getsockname()
|
||||||
|
print(f'Serving on {addr}', flush=True)
|
||||||
|
|
||||||
|
async with server:
|
||||||
|
await server.serve_forever()
|
||||||
|
|
||||||
def handle_accept(self):
|
|
||||||
pair = self.accept()
|
|
||||||
if pair is not None:
|
|
||||||
sock, addr = pair
|
|
||||||
print 'Memcached: Incoming connection'
|
|
||||||
handler = MemcachedHandler(sock, self.db)
|
|
||||||
|
|
||||||
server = MemcachedServer(socket.fromfd(0, socket.AF_UNIX, socket.SOCK_STREAM))
|
|
||||||
try:
|
try:
|
||||||
asyncore.loop()
|
asyncio.run(main())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
@ -1,65 +1,76 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
import socket
|
import socket
|
||||||
import traceback
|
import traceback
|
||||||
|
import typing
|
||||||
|
|
||||||
servsocket = socket.fromfd(0, socket.AF_UNIX, socket.SOCK_STREAM)
|
|
||||||
|
|
||||||
def parsereq(data):
|
@dataclasses.dataclass
|
||||||
a = data.split(':', 1)
|
class ScgiRequest:
|
||||||
if len(a) != 2: return False
|
headers: typing.Dict[bytes, bytes]
|
||||||
hlen, rem = a
|
body: bytes
|
||||||
hlen = int(hlen)
|
|
||||||
if hlen < 16: raise Exception("invalid request")
|
|
||||||
if len(rem) < hlen + 1: return False
|
|
||||||
if rem[hlen] != ',' or rem[hlen-1] != '\0': raise Exception("invalid request")
|
|
||||||
header = rem[:hlen-1]
|
|
||||||
body = rem[hlen+1:]
|
|
||||||
header = header.split('\0')
|
|
||||||
if len(header) < 4: raise Exception("invalid request: not enough header entries")
|
|
||||||
if header[0] != "CONTENT_LENGTH": raise Exception("invalid request: missing CONTENT_LENGTH")
|
|
||||||
clen = int(header[1])
|
|
||||||
if len(body) < clen: return False
|
|
||||||
env = dict()
|
|
||||||
while len(header) > 0:
|
|
||||||
if len(header) == 1: raise Exception("invalid request: missing value for key")
|
|
||||||
key, value = header[0:2]
|
|
||||||
header = header[2:]
|
|
||||||
if '' == key: raise Exception("invalid request: empty key")
|
|
||||||
if key in env: raise Exception("invalid request: duplicate key")
|
|
||||||
env[key] = value
|
|
||||||
if not 'SCGI' in env or env['SCGI'] != '1':
|
|
||||||
raise Exception("invalid request: missing/broken SCGI=1 header")
|
|
||||||
return {'env': env, 'body': body}
|
|
||||||
|
|
||||||
try:
|
|
||||||
while 1:
|
async def parse_scgi_request(reader: asyncio.StreamReader) -> ScgiRequest:
|
||||||
conn, addr = servsocket.accept()
|
hlen = int((await reader.readuntil(b':'))[:-1])
|
||||||
result_status = 200
|
header_raw = await reader.readexactly(hlen + 1)
|
||||||
result = ''
|
assert len(header_raw) >= 16, "invalid request: too short (< 16)"
|
||||||
|
assert header_raw[-2:] == b'\0,', f"Invalid request: missing header/netstring terminator '\\x00,', got {header_raw[-2:]!r}"
|
||||||
|
header_list = header_raw[:-2].split(b'\0')
|
||||||
|
assert len(header_list) % 2 == 0, f"Invalid request: odd numbers of header entries (must be pairs), got {len(header_list)}"
|
||||||
|
assert header_list[0] == b'CONTENT_LENGTH', f"Invalid request: first header entry must be 'CONTENT_LENGTH', got {header_list[0]!r}"
|
||||||
|
clen = int(header_list[1])
|
||||||
|
headers = {}
|
||||||
|
i = 0
|
||||||
|
while i < len(header_list):
|
||||||
|
key = header_list[i]
|
||||||
|
value = header_list[i+1]
|
||||||
|
i += 2
|
||||||
|
assert not key in headers, f"Invalid request: duplicate header key {key!r}"
|
||||||
|
headers[key] = value
|
||||||
|
assert headers.get(b'SCGI') == b'1', "Invalid request: missing SCGI=1 header"
|
||||||
|
body = await reader.readexactly(clen)
|
||||||
|
return ScgiRequest(headers=headers, body=body)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_scgi(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||||
|
print(f"scgi-envcheck: Incoming connection", flush=True)
|
||||||
try:
|
try:
|
||||||
print 'Accepted connection'
|
req = await parse_scgi_request(reader)
|
||||||
data = ''
|
envvar = req.headers[b'QUERY_STRING']
|
||||||
header = False
|
result = req.headers[envvar]
|
||||||
while not header:
|
|
||||||
newdata = conn.recv(1024)
|
|
||||||
if len(newdata) == 0: raise Exception("invalid request: unexpected EOF")
|
|
||||||
data += newdata
|
|
||||||
header = parsereq(data)
|
|
||||||
envvar = header['env']['QUERY_STRING']
|
|
||||||
result = header['env'][envvar]
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print traceback.format_exc()
|
print(traceback.format_exc())
|
||||||
result_status = 500
|
writer.write(b"Status: 500\r\nContent-Type: text/plain\r\n\r\n" + str(e).encode('utf-8'))
|
||||||
result = str(e)
|
else:
|
||||||
try:
|
writer.write(b"Status: 200\r\nContent-Type: text/plain\r\n\r\n" + result)
|
||||||
conn.sendall("Status: " + str(result_status) + "\r\nContent-Type: text/plain\r\n\r\n")
|
await writer.drain()
|
||||||
conn.sendall(result)
|
writer.close()
|
||||||
conn.close()
|
await writer.wait_closed()
|
||||||
except:
|
|
||||||
print traceback.format_exc()
|
|
||||||
|
async def main():
|
||||||
|
sock = socket.socket(fileno=0)
|
||||||
|
if sock.type == socket.AF_UNIX:
|
||||||
|
start_server = asyncio.start_unix_server
|
||||||
|
else:
|
||||||
|
start_server = asyncio.start_server
|
||||||
|
|
||||||
|
server = await start_server(handle_scgi, sock=sock, start_serving=False)
|
||||||
|
|
||||||
|
addr = server.sockets[0].getsockname()
|
||||||
|
print(f'Serving on {addr}', flush=True)
|
||||||
|
|
||||||
|
async with server:
|
||||||
|
await server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user