Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hub Mode for inspector #2881

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Pass Node state
  • Loading branch information
ahopkins committed Jun 23, 2024
commit 44166ad45430341b2c814c1a4eb1b0796e431d23
25 changes: 22 additions & 3 deletions sanic/cli/inspector_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,39 @@ def do(self, action: str, **kwargs: Any) -> None:
sys.stdout.write(out + "\n")

def info(self) -> None:
out = sys.stdout.write
response = self.request("", "GET")
if self.raw or not response:
return
data = response["result"]
display = data.pop("info")
nodes = data.pop("nodes", {})
self._display_info(display)
self._display_workers(data["workers"], None if not nodes else "Hub")
if nodes:
for name, node in nodes.items():
# info = node.pop("info")
workers = node.pop("workers")
# self._display_info(info)
self._display_workers(workers, name)

def _display_info(self, display: Dict[str, Any]) -> None:
extra = display.pop("extra", {})
out = sys.stdout.write
display["packages"] = ", ".join(display["packages"])
MOTDTTY(get_logo(), self.base_url, display, extra).display(
version=False,
action="Inspecting",
out=out,
)
for name, info in data["workers"].items():

def _display_workers(
self, workers: Dict[str, Dict[str, Any]], node: Optional[str] = None
) -> None:
out = sys.stdout.write
for name, info in workers.items():
name = f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}"
if node:
name += f" {Colors.BOLD}{Colors.YELLOW}({node}){Colors.END}"
info = "\n".join(
f"\t{key}: {Colors.BLUE}{value}{Colors.END}"
for key, value in info.items()
Expand All @@ -78,7 +97,7 @@ def info(self) -> None:
+ indent(
"\n".join(
[
f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}",
name,
info,
]
),
Expand Down
170 changes: 143 additions & 27 deletions sanic/worker/inspector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
from __future__ import annotations

from asyncio import sleep, run as run_async
from dataclasses import dataclass
import random

from asyncio import (
Task,
get_running_loop,
sleep,
)
from asyncio import (
run as run_async,
)
from dataclasses import asdict, dataclass
from datetime import datetime
from inspect import isawaitable
from multiprocessing.connection import Connection
from os import environ
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Mapping, Tuple, Union
from signal import SIGTERM, SIGINT, signal

from websockets import ConnectionClosed, connect, connection
from signal import SIGINT, SIGTERM, signal
from string import ascii_lowercase, ascii_uppercase
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Mapping,
Optional,
Tuple,
Union,
)

from websockets import WebSocketException, connection
from websockets.legacy.client import Connect, WebSocketClientProtocol

from sanic.exceptions import Unauthorized
from sanic.helpers import Default, _default
Expand All @@ -20,52 +40,129 @@
from sanic.server.websockets.impl import WebsocketImplProtocol


try:
from ujson import dumps as dump_json
from ujson import loads as load_json
except ImportError:
from json import dumps as dump_json
from json import loads as load_json

if TYPE_CHECKING:
from sanic import Sanic


@dataclass
class NodeState:
...
info: Dict[str, Any]
workers: Dict[str, Any]


@dataclass
class HubState:
nodes: Dict[str, NodeState]


class HubConnection(Connect):
MAX_RETRIES = 6
BACKOFF_MAX = 15

async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
backoff_delay = self.BACKOFF_MIN
failures = 0
while True:
if failures >= self.MAX_RETRIES:
raise RuntimeError(
"Could not connect to bridge "
f"after {self.MAX_RETRIES} retries"
)
try:
async with self as protocol:
if failures > 0:
self.logger.info(
"! connect succeeded after %d failures", failures
)
failures = 0
yield protocol
except Exception:
# Add a random initial delay between 0 and 5 seconds.
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
if backoff_delay == self.BACKOFF_MIN:
initial_delay = random.random() * self.BACKOFF_INITIAL
self.logger.info(
"! connect failed; reconnecting in %.1f seconds",
initial_delay,
)
self.logger.debug("Exception", exc_info=True)
await sleep(initial_delay)
else:
self.logger.info(
"! connect failed again; retrying in %d seconds",
int(backoff_delay),
)
self.logger.debug("Exception", exc_info=True)
await sleep(int(backoff_delay))
# Increase delay with truncated exponential backoff.
backoff_delay = backoff_delay * self.BACKOFF_FACTOR
backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
failures += 1
continue
else:
# Connection succeeded - reset backoff delay
backoff_delay = self.BACKOFF_MIN


class NodeClient:
def __init__(self, hub_host: str, hub_port: int) -> None:
self.hub_host = hub_host
self.hub_port = hub_port
self._run = True
self._heartbeat_task: Optional[Task] = None
self._command_task: Optional[Task] = None

async def run(self, state_getter) -> None:
loop = get_running_loop()
try:
async for ws in connect(f"ws://{self.hub_host}:{self.hub_port}/hub"):
# async with connect(f"ws://{self.hub_host}:{self.hub_port}/hub") as ws:
async for ws in HubConnection(
f"ws://{self.hub_host}:{self.hub_port}/hub"
):
try:
close = await self._run_node(ws, state_getter)
if close:
...
except ConnectionClosed:
...
except BaseException:
...
self._cancel_tasks()
self._heartbeat_task = loop.create_task(
self._heartbeat(ws, state_getter)
)
self._command_task = loop.create_task(self._command(ws))
while self._run:
await sleep(1)
except WebSocketException:
logger.debug("Connection to hub dropped")
finally:
if not self._run:
break
finally:
print("Node out")

def _setup_ws_client(self, hub_host: str, hub_port: int) -> connection:
return connect(f"ws://{hub_host}:{hub_port}/hub")
self._cancel_tasks()
logger.debug("Node client shutting down")

async def _run_node(self, ws: connection, state_getter) -> None:
async def _heartbeat(self, ws: connection, state_getter) -> None:
while self._run:
await ws.send(str(state_getter()))
await ws.send(dump_json(state_getter()))
await sleep(3)
return True

async def _command(self, ws: connection) -> None:
while self._run:
message = await ws.recv()
logger.info("Node received message: %s", message)

def _cancel_tasks(self) -> None:
if self._heartbeat_task:
self._heartbeat_task.cancel()
self._heartbeat_task = None
if self._command_task:
self._command_task.cancel()
self._command_task = None

def close(self, *args):
self._run = False
self._cancel_tasks()


class Inspector:
Expand Down Expand Up @@ -148,7 +245,6 @@ def _detect_modes(
hub_host: str,
hub_port: int,
) -> Tuple[bool, bool]:
print(hub_mode, host, port, hub_host, hub_port)
if hub_host == host and hub_port == port:
if not hub_mode:
raise ValueError(
Expand Down Expand Up @@ -206,6 +302,11 @@ async def _respond(self, request: Request, output: Any):
def _state_to_json(self) -> Dict[str, Any]:
output = {"info": self.app_info}
output["workers"] = self._make_safe(dict(self.worker_state))
if self.hub_mode:
output["nodes"] = {
ident: self._make_safe(asdict(node))
for ident, node in self.app.ctx.hub_state.nodes.items()
}
return output

@staticmethod
Expand Down Expand Up @@ -253,21 +354,29 @@ def shutdown(self) -> None:
self._publisher.send(message)

def _setup_hub(self, app: Sanic) -> None:
logger.info(
f"Sanic Inspector running in hub mode on {self.host}:{self.port}"
)
app.ctx.hub_state = HubState(nodes={})

@staticmethod
async def _hub(
self,
request: Request,
websocket: WebsocketImplProtocol,
) -> None:
hub_state = request.app.ctx.hub_state
hub_state.nodes[request.id] = NodeState()
ident = self._generate_ident()
hub_state.nodes[ident] = NodeState({}, {})
while True:
message = await websocket.recv()
if message == "ping":
await websocket.send("pong")
elif not message:
break
else:
logger.info("Hub received message: %s", message)
raw = load_json(message)
node_state = NodeState(**raw)
hub_state.nodes[ident] = node_state

async def _run_node(self) -> None:
client = NodeClient(self.hub_host, self.hub_port)
Expand All @@ -278,4 +387,11 @@ def signal_close(*args, **kwargs):
signal(SIGTERM, signal_close)
signal(SIGINT, signal_close)

logger.info(
f"Sanic Inspector running in node mode on {self.host}:{self.port}"
)
await client.run(self._state_to_json)

def _generate_ident(self, length: int = 8) -> str:
base = ascii_lowercase + ascii_uppercase
return "".join(random.choices(base, k=length))
Loading