Discuss / Python / 稍微改了改 coroweb.py 和 app.py,并添加了注释,Day 7 测试没问题

稍微改了改 coroweb.py 和 app.py,并添加了注释,Day 7 测试没问题

Topic source

xian_wen

#1 Created at ... [Delete] [Delete and Lock User]

【coroweb.py 修改】:

  1. 注释掉了 fn = asyncio.coroutine(fn)(该语句用于将普通函数转化为协程),不加这一句目前运行也没有问题。官网 Router.add_route() 的文档中提到:"Pay attention please: handler is converted to coroutine internally when it is a regular function",看起来 aiohttp 内部会自动进行转化。

【app.py 修改】:

  1. 由于 loop 已被标记为 Deprecated since version 3.5,故将所有用到 loop 的地方都删掉了,包括 orm.py。
  2. 采用 web.run_app() 代替 loop.create_server() 和 loop.run_until_complete(),随之而来的问题是数据库连接池的创建必须放到协程中运行。解决办法:单独定义一个协程 init_db(),然后绑定到 app.on_startup,使得 app 启动后,数据库连接池随之创建。
# coroweb.py

# -*- coding: utf-8 -*-
# @author xian_wen
# @date 6/1/2021 12:26 PM

# import asyncio
import functools
import inspect
import logging
import os

from urllib import parse

from aiohttp import web

from apis import APIError


def get(path):
    """
    Build the ``@get('/path')`` decorator for a URL handler.

    The decorator wraps the handler in a pass-through function and tags that
    wrapper with ``__method__ = 'GET'`` and ``__route__ = path``.

    :param path: URL path stored on the wrapper as ``__route__``.
    :return: A decorator that takes a handler and returns the tagged wrapper.
    """

    def decorator(func):
        def passthrough(*args, **kwargs):
            return func(*args, **kwargs)

        # Copy func's metadata onto the wrapper, so that
        # wrapper.__name__ == func.__name__ (and __wrapped__ points at func).
        tagged = functools.wraps(func)(passthrough)
        tagged.__method__ = 'GET'
        tagged.__route__ = path
        return tagged

    return decorator


def post(path):
    """
    Build the ``@post('/path')`` decorator for a URL handler.

    The decorator wraps the handler in a pass-through function and tags that
    wrapper with ``__method__ = 'POST'`` and ``__route__ = path``.

    :param path: URL path stored on the wrapper as ``__route__``.
    :return: A decorator that takes a handler and returns the tagged wrapper.
    """

    def decorator(func):
        def passthrough(*args, **kwargs):
            return func(*args, **kwargs)

        # Copy func's metadata onto the wrapper, so that
        # wrapper.__name__ == func.__name__ (and __wrapped__ points at func).
        tagged = functools.wraps(func)(passthrough)
        tagged.__method__ = 'POST'
        tagged.__route__ = path
        return tagged

    return decorator


def get_required_kwargs(fn):
    """
    Collect the names of keyword-only parameters that have no default value.

    :param fn: function to inspect
    :return: tuple of parameter names, in signature order
    """
    signature = inspect.signature(fn)
    # Keyword-only parameters are those declared after ``*`` or ``*args``;
    # "required" means no default was supplied for them.
    return tuple(
        name
        for name, param in signature.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
        and param.default is inspect.Parameter.empty
    )


def get_named_kwargs(fn):
    """
    Collect the names of all keyword-only parameters, with or without defaults.

    :param fn: function to inspect
    :return: tuple of parameter names, in signature order
    """
    signature = inspect.signature(fn)
    # Keyword-only parameters are those declared after ``*`` or ``*args``.
    return tuple(
        name
        for name, param in signature.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    )


def has_named_kwarg(fn):
    """
    Tell whether the function declares at least one keyword-only parameter.

    Always returns a real ``bool``; previously the "not found" case fell off
    the end of the function and returned ``None``.

    :param fn: function to inspect
    :return: True if any parameter is keyword-only, otherwise False
    """
    params = inspect.signature(fn).parameters
    # Keyword-only parameters are those declared after ``*`` or ``*args``.
    return any(param.kind == param.KEYWORD_ONLY for param in params.values())


def has_var_kwarg(fn):
    """
    Tell whether the function declares a ``**kwargs`` catch-all parameter.

    Always returns a real ``bool``; previously the "not found" case fell off
    the end of the function and returned ``None``.

    :param fn: function to inspect
    :return: True if a ``**kwargs``-style parameter exists, otherwise False
    """
    params = inspect.signature(fn).parameters
    # VAR_KEYWORD is the kind of the ``**kwargs`` parameter itself.
    return any(param.kind == param.VAR_KEYWORD for param in params.values())


def has_request_arg(fn):
    """
    Tell whether the function declares a parameter named ``request``.

    Any parameter that follows ``request`` must be ``*args``, keyword-only or
    ``**kwargs``; an ordinary positional parameter after it is rejected.

    :param fn: function to inspect
    :return: True if a ``request`` parameter exists, otherwise False
    :raises ValueError: if a positional parameter follows ``request``
    """
    signature = inspect.signature(fn)
    # Parameter kinds that are allowed to appear after ``request``.
    allowed_after_request = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.VAR_KEYWORD,
    )
    seen_request = False
    for name, param in signature.parameters.items():
        if name == 'request':
            seen_request = True
        elif seen_request and param.kind not in allowed_after_request:
            # The message shows the function as name + signature,
            # e.g. fn(request, a).
            raise ValueError(
                'Request parameter must be the last named parameter in function: %s%s'
                % (fn.__name__, str(signature)))
    return seen_request


class RequestHandler(object):
    """
    Callable adapter that lets a URL handler function be registered as an
    aiohttp request handler.

    On each request it collects arguments from the request body (POST), the
    query string (GET) and the route's ``match_info``, filters them against
    the handler's signature, and calls the handler with them as keyword
    arguments.
    """

    def __init__(self, app, fn):
        """
        :param app: the application object (stored, not otherwise used here)
        :param fn: the URL handler function to adapt
        """
        self.__app = app
        self.__func = fn
        # Inspect the handler's signature once, up front.
        self.__has_request_arg = has_request_arg(fn)
        self.__has_var_kwarg = has_var_kwarg(fn)
        self.__has_named_kwarg = has_named_kwarg(fn)
        self.__named_kwargs = get_named_kwargs(fn)
        self.__required_kwargs = get_required_kwargs(fn)

    # Make RequestHandler callable
    async def __call__(self, request):
        """
        Build keyword arguments from the request and invoke the handler.

        :param request: the incoming request object
        :return: whatever the handler returns, or a dict with ``error``,
                 ``data`` and ``message`` keys when the handler raises APIError
        :raises web.HTTPBadRequest: on a missing or unsupported Content-Type,
                 a malformed or non-dict JSON body, or a missing required
                 argument
        """
        kwargs = None
        # Only read the body / query string if the handler can accept
        # keyword arguments at all.
        if self.__has_var_kwarg or self.__has_named_kwarg or self.__required_kwargs:
            if request.method == 'POST':
                # HTTP errors are raised, not returned: aiohttp deprecates
                # returning HTTP exception objects from handlers.
                if not request.content_type:
                    raise web.HTTPBadRequest(text='Missing Content-Type.')
                ct = request.content_type.lower()
                # JSON body
                if ct.startswith('application/json'):
                    try:
                        # Read request body decoded as json
                        params = await request.json()
                    except ValueError:
                        # Malformed JSON (json.JSONDecodeError is a ValueError)
                        # is a client error, not a server error.
                        raise web.HTTPBadRequest(text='Invalid JSON body.')
                    if not isinstance(params, dict):
                        raise web.HTTPBadRequest(text='JSON body must be dict object.')
                    kwargs = params
                # Form data encoded as key/value pairs (the default encoding
                # for HTML form submission) or multipart form data
                elif ct.startswith(('application/x-www-form-urlencoded', 'multipart/form-data')):
                    # Read POST parameters from request body
                    params = await request.post()
                    kwargs = dict(**params)
                else:
                    raise web.HTTPBadRequest(text='Unsupported Content-Type: %s' % request.content_type)
            if request.method == 'GET':
                # The query string in the URL, e.g., id=10
                qs = request.query_string
                if qs:
                    # parse_qs yields {'id': ['10']}; keep the first value of each key.
                    kwargs = {k: v[0] for k, v in parse.parse_qs(qs, True).items()}
        if kwargs is None:
            # Nothing came from body/query: use the route's matched variables only.
            kwargs = dict(**request.match_info)
        else:
            if not self.__has_var_kwarg and self.__named_kwargs:
                # Without **kwargs the handler cannot take extras, so keep
                # only the keyword-only names it declares.
                kwargs = {name: kwargs[name] for name in self.__named_kwargs if name in kwargs}
            # Route variables override same-named body/query arguments.
            for k, v in request.match_info.items():
                if k in kwargs:
                    logging.warning('Duplicate arg name in named kwargs and kwargs: %s' % k)
                kwargs[k] = v
        if self.__has_request_arg:
            kwargs['request'] = request
        # Check required kwargs
        for name in self.__required_kwargs:
            if name not in kwargs:
                raise web.HTTPBadRequest(text='Missing argument: %s' % name)
        logging.info('Call with kwargs: %s' % str(kwargs))
        try:
            return await self.__func(**kwargs)
        except APIError as e:
            # Convert API errors into a plain dict describing the failure.
            return dict(error=e.error, data=e.data, message=e.message)


def add_static(app):
    """
    Register the ``static`` directory next to this file under ``/static/``.

    :param app: application whose router receives the static route
    """
    url_prefix = '/static/'
    module_dir = os.path.dirname(os.path.abspath(__file__))
    # e.g. /www/static
    static_dir = os.path.join(module_dir, 'static')
    # Add a router and a handler for returning static files.
    # Development only; in production, use web servers like nginx or apache.
    app.router.add_static(url_prefix, static_dir)
    logging.info('Add static %s => %s' % (url_prefix, static_dir))


def add_route(app, fn):
    """
    Register one URL handler, tagged by ``@get`` or ``@post``, on the app.

    :param app: application whose router receives the route
    :param fn: handler carrying ``__method__`` and ``__route__`` attributes
    :raises ValueError: if either attribute is missing
    """
    http_method = getattr(fn, '__method__', None)
    route_path = getattr(fn, '__route__', None)
    if http_method is None or route_path is None:
        raise ValueError('@get or @post not defined in %s.' % str(fn))
    # Logged as e.g. "GET / => fn(arg1, arg2)".
    param_names = ', '.join(inspect.signature(fn).parameters.keys())
    logging.info('Add route %s %s => %s(%s)' % (http_method, route_path, fn.__name__, param_names))
    # Attention: handler is converted to coroutine internally when it is a regular function
    app.router.add_route(http_method, route_path, RequestHandler(app, fn))


def add_routes(app, module_name):
    """
    Import a module by name and register every URL handler found in it.

    A handler is any callable module attribute, whose name does not start
    with ``__``, that carries truthy ``__method__`` and ``__route__``
    attributes (as set by ``@get`` / ``@post``).

    :param app: application passed through to ``add_route``
    :param module_name: module name, plain (``module``) or dotted
                        (``package.module``)
    :raises ImportError: if the module cannot be imported
    """
    # Imported locally so this function does not depend on the file header.
    import importlib

    # import_module returns the leaf module for both plain and dotted names,
    # replacing the manual __import__/fromlist/getattr handling.
    mod = importlib.import_module(module_name)
    for attr in dir(mod):
        if attr.startswith('__'):
            continue
        fn = getattr(mod, attr)
        if not callable(fn):
            continue
        method = getattr(fn, '__method__', None)
        path = getattr(fn, '__route__', None)
        if method and path:
            add_route(app, fn)
# app.py

# -*- coding: utf-8 -*-
# @author xian_wen
# @date 5/26/2021 2:06 PM

import json
import logging
import os
import time
from datetime import datetime

from aiohttp import web
from jinja2 import Environment, FileSystemLoader

import orm
from coroweb import add_routes, add_static
from config import configs

logging.basicConfig(level=logging.INFO)


def init_jinja2(app, **kwargs):
    """
    Create a jinja2 Environment and store it on the app as ``__templating__``.

    :param app: application that receives the environment
    :param kwargs: optional overrides: ``autoescape``, ``block_start_string``,
                   ``block_end_string``, ``variable_start_string``,
                   ``variable_end_string``, ``auto_reload``, ``path``
                   (template directory) and ``filters`` (name -> function)
    """
    logging.info('Init jinja2...')
    # Environment options and the defaults used when the caller gives none.
    option_defaults = (
        ('autoescape', True),
        ('block_start_string', '{%'),
        ('block_end_string', '%}'),
        ('variable_start_string', '{{'),
        ('variable_end_string', '}}'),
        ('auto_reload', True),
    )
    options = {key: kwargs.get(key, fallback) for key, fallback in option_defaults}
    template_dir = kwargs.get('path', None)
    if template_dir is None:
        # Default to the "templates" directory next to this file, e.g. /www/templates
        here = os.path.dirname(os.path.abspath(__file__))
        template_dir = os.path.join(here, 'templates')
    logging.info('Set jinja2 template path: %s' % template_dir)
    # Load templates from a directory in the file system
    env = Environment(loader=FileSystemLoader(template_dir), **options)
    custom_filters = kwargs.get('filters', None)
    if custom_filters is not None:
        # Filters are Python functions, registered under the given names.
        for filter_name, filter_func in custom_filters.items():
            env.filters[filter_name] = filter_func
    app['__templating__'] = env


async def logger_factory(app, handler):
    """
    Middleware factory: wrap ``handler`` so each request is logged first.

    :param app: the application (unused here)
    :param handler: the next handler in the chain
    :return: a coroutine function that logs the request's method and path,
             then returns the wrapped handler's result
    """

    async def logger(request):
        logging.info('Request: %s %s' % (request.method, request.path))
        result = await handler(request)
        return result

    return logger


async def data_factory(app, handler):
    """
    Middleware factory: parse the request body into ``request.__data__``.

    JSON bodies are read with ``request.json()`` and url-encoded form bodies
    with ``request.post()``; any other content type leaves the request
    untouched. The wrapped handler is then called in every case.

    :param app: the application (unused here)
    :param handler: the next handler in the chain
    :return: a coroutine function wrapping ``handler``
    """

    async def parse_data(request):
        # (content-type prefix, body reader, log message), checked in order.
        readers = (
            # JSON body
            ('application/json', request.json, 'Request json: %s'),
            # Form data encoded as key/value pairs (the default form encoding)
            ('application/x-www-form-urlencoded', request.post, 'Request form: %s'),
        )
        for prefix, read_body, message in readers:
            if request.content_type.startswith(prefix):
                request.__data__ = await read_body()
                logging.info(message % str(request.__data__))
                break
        return await handler(request)

    return parse_data


async def response_factory(app, handler):
    """
    Middleware factory: convert a handler's return value into a web response.

    Conversion rules, checked in this order:

    - ``web.StreamResponse``: returned unchanged
    - ``bytes``: ``application/octet-stream`` body
    - ``str``: redirect if it starts with ``redirect:``, otherwise HTML
    - ``dict``: rendered through the template named by ``__template__``,
      or serialized to JSON when that key is absent
    - ``int`` in [100, 600): empty response with that status code
    - 2-tuple ``(status, reason)`` with status in [100, 600): that status
      and reason phrase
    - anything else: ``str(r)`` as plain text

    :param app: the application; ``app['__templating__']`` supplies templates
    :param handler: the next handler in the chain
    :return: a coroutine function wrapping ``handler``
    :raises web.HTTPFound: (from the wrapper) for ``redirect:`` strings
    """
    redirect_prefix = 'redirect:'

    async def response(request):
        logging.info('Response handler...')
        r = await handler(request)
        # The base class for the HTTP response handling
        if isinstance(r, web.StreamResponse):
            return r
        if isinstance(r, bytes):
            resp = web.Response(body=r)
            # Binary stream data (e.g. an ordinary file download)
            resp.content_type = 'application/octet-stream'
            return resp
        if isinstance(r, str):
            if r.startswith(redirect_prefix):
                # Raise rather than return: aiohttp deprecates returning
                # HTTP exception objects and turns raised ones into responses.
                raise web.HTTPFound(r[len(redirect_prefix):])
            resp = web.Response(body=r.encode('utf-8'))
            # HTML
            resp.content_type = 'text/html; charset=UTF-8'
            return resp
        if isinstance(r, dict):
            template = r.get('__template__')
            if template is None:
                resp = web.Response(
                    # ensure_ascii=False keeps non-ASCII characters as-is.
                    # default=... serializes otherwise unsupported objects
                    # through their __dict__ (their writable attributes).
                    body=json.dumps(r, ensure_ascii=False, default=lambda o: o.__dict__).encode('utf-8'))
                # JSON
                resp.content_type = 'application/json; charset=UTF-8'
                return resp
            else:
                # app['__templating__'] is the Environment object: load the
                # template and render it with the dict as its context.
                resp = web.Response(body=app['__templating__'].get_template(template).render(**r).encode('utf-8'))
                resp.content_type = 'text/html; charset=UTF-8'
                return resp
        # Status Code
        if isinstance(r, int) and 100 <= r < 600:
            return web.Response(status=r)
        # Status Code and Reason Phrase
        if isinstance(r, tuple) and len(r) == 2:
            t, m = r
            if isinstance(t, int) and 100 <= t < 600:
                # 1xx: Informational - Request received, continuing process
                # 2xx: Success - The action was successfully received, understood, and accepted
                # 3xx: Redirection - Further action must be taken in order to complete the request
                # 4xx: Client Error - The request contains bad syntax or cannot be fulfilled
                # 5xx: Server Error - The server failed to fulfill an apparently valid request
                return web.Response(status=t, reason=str(m))
        # Default: plain text
        resp = web.Response(body=str(r).encode('utf-8'))
        resp.content_type = 'text/plain; charset=UTF-8'
        return resp

    return response


def datetime_filter(t):
    """
    Format a timestamp as a relative-age string (in Chinese).

    Under a minute (or in the future) gives "1 minute ago"; under an hour,
    a day or a week gives whole minutes, hours or days ago; anything older
    gives the local calendar date.

    :param t: timestamp in seconds, comparable with ``time.time()``
    :return: the formatted string
    """
    elapsed = int(time.time() - t)
    if elapsed < 60:  # 1 min
        return u'1分钟前'
    # (upper bound in seconds, seconds per unit, output template)
    buckets = (
        (3600, 60, u'%s分钟前'),  # under 1 h: minutes
        (86400, 3600, u'%s小时前'),  # under 24 h: hours
        (604800, 86400, u'%s天前'),  # under 7 days: days
    )
    for upper_bound, unit_seconds, template in buckets:
        if elapsed < upper_bound:
            # // is floor division, e.g. 3 // 2 == 1
            return template % (elapsed // unit_seconds)
    moment = datetime.fromtimestamp(t)
    return u'%s年%s月%s日' % (moment.year, moment.month, moment.day)


async def init_db(app):
    """
    Create the database connection pool from the ``configs.db`` settings.

    :param app: the application (unused here)
    """
    db = configs.db
    # If on Linux, use another user instead of 'root'
    pool_settings = dict(
        host=db.host,
        port=db.port,
        user=db.user,
        password=db.password,
        db=db.database,
    )
    await orm.create_pool(**pool_settings)


# Build the application with its middleware factories.
app = web.Application(middlewares=[
    logger_factory,
    response_factory
])
# Set up templating and register datetime_filter as a template filter.
# NOTE(review): the filter is registered under the key 'datatime', which looks
# like a typo of 'datetime' - confirm against the filter name the templates use.
init_jinja2(app, filters=dict(datatime=datetime_filter))
# Register the URL handlers found in the 'handlers' module, then static files.
add_routes(app, 'handlers')
add_static(app)
# init_db is a coroutine, so it is attached as a startup callback and the
# connection pool is created when the app starts.
app.on_startup.append(init_db)
# Serve on localhost:9000.
web.run_app(app, host='localhost', port=9000)

  • 1

Reply