Skip to content

Commit

Permalink
[py] Implement script module for BiDi
Browse files Browse the repository at this point in the history
  • Loading branch information
p0deje committed Jun 8, 2024
1 parent 40f684e commit 6d1358d
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ py/selenium/webdriver/remote/isDisplayed.js
py/docs/build/
py/build/
py/LICENSE
py/pytestdebug.log
selenium.egg-info/
third_party/java/jetty/jetty-repacked.jar
*.user
Expand Down
17 changes: 17 additions & 0 deletions py/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def pytest_addoption(parser):
dest="use_lan_ip",
help="Whether to start test server with lan ip instead of localhost",
)
parser.addoption(
"--bidi",
action="store",
dest="bidi",
metavar="BIDI",
default=True,
help="Whether to enable BiDi support",
)


def pytest_ignore_collect(path, config):
Expand Down Expand Up @@ -166,6 +174,7 @@ def get_options(driver_class, config):
browser_path = config.option.binary
browser_args = config.option.args
headless = bool(config.option.headless)
bidi = bool(config.option.bidi)
options = None

if browser_path or browser_args:
Expand All @@ -187,6 +196,14 @@ def get_options(driver_class, config):
options.add_argument("--headless=new")
if driver_class == "Firefox":
options.add_argument("-headless")

if bidi:
if not options:
options = getattr(webdriver, f"{driver_class}Options")()

if driver_class == "Chrome" or driver_class == "Edge" or driver_class == "Firefox":
options.web_socket_url = True

return options


Expand Down
111 changes: 111 additions & 0 deletions py/selenium/webdriver/common/bidi/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import typing
from dataclasses import dataclass

from .session import session_subscribe
from .session import session_unsubscribe


class Script:
def __init__(self, conn):
self.conn = conn
self.log_entry_subscribed = False

def add_console_message_handler(self, handler):
self._subscribe_to_log_entries()
return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler))

def add_javascript_error_handler(self, handler):
self._subscribe_to_log_entries()
return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("javascript", handler))

def remove_console_message_handler(self, id):
self.conn.remove_callback(LogEntryAdded, id)
self._unsubscribe_from_log_entries()

remove_javascript_error_handler = remove_console_message_handler

def _subscribe_to_log_entries(self):
if not self.log_entry_subscribed:
self.conn.execute(session_subscribe(LogEntryAdded.event_class))
self.log_entry_subscribed = True

def _unsubscribe_from_log_entries(self):
if self.log_entry_subscribed and LogEntryAdded.event_class not in self.conn.callbacks:
self.conn.execute(session_unsubscribe(LogEntryAdded.event_class))
self.log_entry_subscribed = False

def _handle_log_entry(self, type, handler):
def _handle_log_entry(log_entry):
if log_entry.type_ == type:
handler(log_entry)

return _handle_log_entry


class LogEntryAdded:
event_class = "log.entryAdded"

@classmethod
def from_json(cls, json):
print(json)
if json["type"] == "console":
return ConsoleLogEntry.from_json(json)
elif json["type"] == "javascript":
return JavaScriptLogEntry.from_json(json)


@dataclass
class ConsoleLogEntry:
level: str
text: str
timestamp: str
method: str
args: typing.List[dict]
type_: str

@classmethod
def from_json(cls, json):
return cls(
level=json["level"],
text=json["text"],
timestamp=json["timestamp"],
method=json["method"],
args=json["args"],
type_=json["type"],
)


@dataclass
class JavaScriptLogEntry:
level: str
text: str
timestamp: str
stacktrace: dict
type_: str

@classmethod
def from_json(cls, json):
return cls(
level=json["level"],
text=json["text"],
timestamp=json["timestamp"],
stacktrace=json["stackTrace"],
type_=json["type"],
)
42 changes: 42 additions & 0 deletions py/selenium/webdriver/common/bidi/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


def session_subscribe(*events, browsing_contexts=[]):
cmd_dict = {
"method": "session.subscribe",
"params": {
"events": events,
},
}
if browsing_contexts:
cmd_dict["params"]["browsingContexts"] = browsing_contexts
_ = yield cmd_dict
return None


def session_unsubscribe(*events, browsing_contexts=[]):
cmd_dict = {
"method": "session.unsubscribe",
"params": {
"events": events,
},
}
if browsing_contexts:
cmd_dict["params"]["browsingContexts"] = browsing_contexts
_ = yield cmd_dict
return None
30 changes: 29 additions & 1 deletion py/selenium/webdriver/common/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def __init__(self, name):
self.name = name

def __get__(self, obj, cls):
if self.name in ("acceptInsecureCerts", "strictFileInteractability", "setWindowRect", "se:downloadsEnabled"):
if self.name in (
"acceptInsecureCerts",
"strictFileInteractability",
"setWindowRect",
"se:downloadsEnabled",
"webSocketUrl",
):
return obj._caps.get(self.name, False)
return obj._caps.get(self.name)

Expand Down Expand Up @@ -361,6 +367,28 @@ class BaseOptions(metaclass=ABCMeta):
- `None`
"""

web_socket_url = _BaseOptionsDescriptor("webSocketUrl")
"""Gets and Sets WebSocket URL.
Usage
-----
- Get
- `self.web_socket_url`
- Set
- `self.web_socket_url` = `value`
Parameters
----------
`value`: `bool`
Returns
-------
- Get
- `bool`
- Set
- `None`
"""

def __init__(self) -> None:
super().__init__()
self._caps = self.default_capabilities
Expand Down
21 changes: 21 additions & 0 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from selenium.common.exceptions import NoSuchCookieException
from selenium.common.exceptions import NoSuchElementException
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.common.bidi.script import Script
from selenium.webdriver.common.by import By
from selenium.webdriver.common.options import BaseOptions
from selenium.webdriver.common.print_page_options import PrintOptions
Expand Down Expand Up @@ -209,7 +210,9 @@ def __init__(
self._authenticator_id = None
self.start_client()
self.start_session(capabilities)

self._websocket_connection = None
self._script = None

def __repr__(self):
return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>'
Expand Down Expand Up @@ -1067,6 +1070,24 @@ async def bidi_connection(self):
async with conn.open_session(target_id) as session:
yield BidiConnection(session, cdp, devtools)

@property
def script(self):
if not self._websocket_connection:
self._start_bidi()

if not self._script:
self._script = Script(self._websocket_connection)

return self._script

def _start_bidi(self):
if self.caps.get("webSocketUrl"):
ws_url = self.caps.get("webSocketUrl")
else:
raise WebDriverException("Unable to find url to connect to from capabilities")

self._websocket_connection = WebSocketConnection(ws_url)

def _get_cdp_details(self):
import json

Expand Down
45 changes: 33 additions & 12 deletions py/selenium/webdriver/remote/websocket_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from websocket import WebSocketApp

logger = logging.getLogger("websocket")
logger = logging.getLogger(__name__)


class WebSocketConnection:
Expand All @@ -32,11 +32,11 @@ class WebSocketConnection:
_max_log_message_size = 9999

def __init__(self, url):
self.callbacks = {}
self.session_id = None
self.url = url

self._id = 0
self._callbacks = {}
self._messages = {}
self._started = False

Expand All @@ -57,17 +57,38 @@ def execute(self, command):
payload["sessionId"] = self.session_id

data = json.dumps(payload)
logger.debug(f"WebSocket -> {data}"[: self._max_log_message_size])
logger.debug(f"-> {data}"[: self._max_log_message_size])
self._ws.send(data)

self._wait_until(lambda: self._id in self._messages)
result = self._messages.pop(self._id)["result"]
return self._deserialize_result(result, command)
response = self._messages.pop(self._id)

def on(self, event, callback):
if event not in self._callbacks:
self._callbacks[event.event_class] = []
self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params)))
if "error" in response:
raise Exception(response["error"])
else:
result = response["result"]
return self._deserialize_result(result, command)

def add_callback(self, event, callback):
event_name = event.event_class
if event_name not in self.callbacks:
self.callbacks[event_name] = []

def _callback(params):
callback(event.from_json(params))

self.callbacks[event_name].append(_callback)
return id(_callback)

on = add_callback

def remove_callback(self, event, callback_id):
event_name = event.event_class
if event_name in self.callbacks:
for callback in self.callbacks[event_name]:
if id(callback) == callback_id:
self.callbacks[event_name].remove(callback)
return

def _serialize_command(self, command):
return next(command)
Expand All @@ -87,7 +108,7 @@ def on_message(ws, message):
self._process_message(message)

def on_error(ws, error):
logger.debug(f"WebSocket error: {error}")
logger.debug(f"error: {error}")
ws.close()

def run_socket():
Expand All @@ -102,14 +123,14 @@ def run_socket():

def _process_message(self, message):
message = json.loads(message)
logger.debug(f"WebSocket <- {message}"[: self._max_log_message_size])
logger.debug(f"<- {message}"[: self._max_log_message_size])

if "id" in message:
self._messages[message["id"]] = message

if "method" in message:
params = message["params"]
for callback in self._callbacks.get(message["method"], []):
for callback in self.callbacks.get(message["method"], []):
callback(params)

def _wait_until(self, condition):
Expand Down
Loading

0 comments on commit 6d1358d

Please sign in to comment.