|
3 | 3 | import decimal |
4 | 4 | import random |
5 | 5 | import socket |
| 6 | +import unittest |
6 | 7 | import weakref |
7 | 8 |
|
8 | 9 | from uvloop import _testbase as tb |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class _Protocol(asyncio.Protocol): |
12 | | - def __init__(self, *, loop=None): |
| 13 | + def __init__(self, cvar, *, loop=None): |
| 14 | + self.cvar = cvar |
| 15 | + self.transport = None |
| 16 | + self.connection_made_fut = asyncio.Future(loop=loop) |
| 17 | + self.connection_lost_ctx = None |
13 | 18 | self.done = asyncio.Future(loop=loop) |
14 | 19 |
|
| 20 | + def connection_made(self, transport): |
| 21 | + self.transport = transport |
| 22 | + self.connection_made_fut.set_result(self.cvar.get()) |
| 23 | + |
15 | 24 | def connection_lost(self, exc): |
| 25 | + self.connection_lost_ctx = self.cvar.get() |
16 | 26 | if exc is None: |
17 | 27 | self.done.set_result(None) |
18 | 28 | else: |
@@ -141,7 +151,7 @@ async def main(): |
141 | 151 | def test_create_server_protocol_factory_context(self): |
142 | 152 | cvar = contextvars.ContextVar('cvar', default='outer') |
143 | 153 | factory_called_future = self.loop.create_future() |
144 | | - proto = _Protocol(loop=self.loop) |
| 154 | + proto = _Protocol(cvar, loop=self.loop) |
145 | 155 |
|
146 | 156 | def factory(): |
147 | 157 | try: |
@@ -172,6 +182,71 @@ async def test(): |
172 | 182 |
|
173 | 183 | self.loop.run_until_complete(test()) |
174 | 184 |
|
| 185 | + def test_create_server_connection_made_lost(self): |
| 186 | + cvar = contextvars.ContextVar('cvar', default='outer') |
| 187 | + proto = _Protocol(cvar, loop=self.loop) |
| 188 | + |
| 189 | + async def test(): |
| 190 | + cvar.set('inner') |
| 191 | + port = tb.find_free_port() |
| 192 | + srv = await self.loop.create_server( |
| 193 | + lambda: proto, '127.0.0.1', port, |
| 194 | + ) |
| 195 | + |
| 196 | + s = socket.socket(socket.AF_INET) |
| 197 | + with s: |
| 198 | + s.setblocking(False) |
| 199 | + await self.loop.sock_connect(s, ('127.0.0.1', port)) |
| 200 | + |
| 201 | + try: |
| 202 | + inner = await proto.connection_made_fut |
| 203 | + self.assertEqual(inner, "inner") |
| 204 | + |
| 205 | + await proto.done |
| 206 | + self.assertEqual(proto.connection_lost_ctx, "inner") |
| 207 | + finally: |
| 208 | + srv.close() |
| 209 | + await srv.wait_closed() |
| 210 | + |
| 211 | + self.loop.run_until_complete(test()) |
| 212 | + |
| 213 | + def test_create_server_manual_connection_lost(self): |
| 214 | + if self.implementation == 'asyncio': |
| 215 | + raise unittest.SkipTest('this seems to be a bug in asyncio') |
| 216 | + |
| 217 | + cvar = contextvars.ContextVar('cvar', default='outer') |
| 218 | + proto = _Protocol(cvar, loop=self.loop) |
| 219 | + |
| 220 | + async def close(): |
| 221 | + cvar.set('closing') |
| 222 | + proto.transport.close() |
| 223 | + |
| 224 | + async def test(): |
| 225 | + cvar.set('inner') |
| 226 | + port = tb.find_free_port() |
| 227 | + srv = await self.loop.create_server( |
| 228 | + lambda: proto, '127.0.0.1', port, |
| 229 | + ) |
| 230 | + |
| 231 | + s = socket.socket(socket.AF_INET) |
| 232 | + s.setblocking(False) |
| 233 | + await self.loop.sock_connect(s, ('127.0.0.1', port)) |
| 234 | + |
| 235 | + try: |
| 236 | + inner = await proto.connection_made_fut |
| 237 | + self.assertEqual(inner, "inner") |
| 238 | + |
| 239 | + await asyncio.ensure_future(close()) |
| 240 | + |
| 241 | + await proto.done |
| 242 | + self.assertEqual(proto.connection_lost_ctx, "inner") |
| 243 | + finally: |
| 244 | + s.close() |
| 245 | + srv.close() |
| 246 | + await srv.wait_closed() |
| 247 | + |
| 248 | + self.loop.run_until_complete(test()) |
| 249 | + |
175 | 250 |
|
176 | 251 | class Test_UV_Context(_ContextBaseTests, tb.UVTestCase): |
177 | 252 | pass |
|
0 commit comments