#pike __REAL_VERSION__ |
|
#ifdef WEBSOCKET_DEBUG |
# define WS_WERR(level, x...) do { if (WEBSOCKET_DEBUG >= level) { werror("%O: ", this); werror(x); } } while(0) |
#else |
# define WS_WERR(level, x...) |
#endif |
|
//! GUID that is appended to the Sec-WebSocket-Key value before SHA-1
//! hashing to derive the Sec-WebSocket-Accept handshake value (used on
//! both the client and the server side of the handshake).
constant websocket_id = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

//! Protocol version sent in, and required of, the Sec-WebSocket-Version
//! header.
constant websocket_version = 13;

//! Payload masking helper, called as @expr{MASK(data, mask)@} with a
//! 4-byte key both to unmask received frames and to mask outgoing ones.
//! NOTE(review): native implementation in _Roxen; presumably the RFC 6455
//! XOR mask -- confirm against _Roxen.websocket_mask.
constant MASK = _Roxen.websocket_mask;

//! Identification string sent as User-Agent (client) and Server (server).
private constant agent = sprintf("Pike/%d.%d", __MAJOR__, __MINOR__);

//! Signature of a message callback: the received @[Frame] and a
//! user-supplied id.  NOTE(review): appears unused in this file.
typedef function(Frame, mixed:void) message_callback;
|
//! Bind trailing arguments to a callback.
//!
//! @returns
//!   A function that, when called with some arguments, calls @[cb] with
//!   those arguments followed by @[extra_args].  Used to attach state to
//!   I/O callbacks.
function curry_back(function cb, mixed ... extra_args) {
  // The caller's arguments come first, the bound ones last.
  void bound(mixed ... call_args) {
    array(mixed) all_args = call_args + extra_args;
    cb(@all_args);
  };
  return bound;
}
|
|
//! WebSocket frame opcodes (low four bits of the first header byte).
enum FRAME {

  //! Continuation of a fragmented message (0x0).
  FRAME_CONTINUATION = 0x0,

  //! Text frame (0x1); the payload is UTF-8 encoded.
  FRAME_TEXT,

  //! Binary frame (0x2); the payload is raw 8-bit data.
  FRAME_BINARY,

  //! Close frame (0x8); the payload optionally holds a 2-byte status code
  //! followed by UTF-8 reason text.
  FRAME_CLOSE = 0x8,

  //! Ping frame (0x9); answered with a pong carrying the same payload.
  FRAME_PING,

  //! Pong frame (0xA).
  FRAME_PONG,
};
|
|
|
//! Close status codes, as carried in close frames (RFC 6455 section 7.4.1).
enum CLOSE_STATUS {

  //! Normal closure (1000).
  CLOSE_NORMAL = 1000,

  //! Endpoint going away (1001).
  CLOSE_GONE_AWAY,

  //! Protocol error (1002); the default status used by fail().
  CLOSE_ERROR,

  //! Unsupported data type (1003).
  CLOSE_BAD_TYPE,

  //! No status code present (1005).  Not valid on the wire:
  //! is_valid_close() rejects it.
  CLOSE_NONE = 1005,

  //! Invalid payload data, e.g. non-UTF-8 text (1007).
  CLOSE_BAD_DATA = 1007,

  //! Policy violation (1008).
  CLOSE_POLICY,

  //! Message too big (1009).
  CLOSE_OVERFLOW,

  //! Extension-related failure (1010).
  CLOSE_EXTENSION = 1010,

  //! Unexpected condition (1011).
  CLOSE_UNEXPECTED,
};
|
|
//! Masks for the reserved bits of the first frame header byte, as stored
//! in @[Frame()->rsv].
enum RSV {

  //! RSV1 (0x40); used by permessage-deflate to flag compressed messages.
  RSV1 = 0x40,

  //! RSV2 (0x20).
  RSV2 = 0x20,

  //! RSV3 (0x10).
  RSV3 = 0x10

};
|
|
//! Values for the @expr{"compressionHeuristics"@} permessage-deflate option.
enum COMPRESSION {

  //! Decide per message whether compression is worthwhile (0).
  HEURISTICS_COMPRESS = 0,

  //! Always compress data messages (1).
  OVERRIDE_COMPRESS,

};
|
//! Check whether @[close_status] may appear in a received close frame.
//!
//! @returns
//!   1 for the status codes 1000-1003 and 1007-1011, and for the
//!   private-use range 3000-4999 (RFC 6455 section 7.4.2); 0 otherwise
//!   (this includes 1005, which must never be sent).
int(0..1) is_valid_close(int close_status) {
  // Library/application-defined codes.
  if (close_status >= 3000 && close_status <= 4999) return 1;

  switch (close_status) {
    case CLOSE_NORMAL .. CLOSE_BAD_TYPE:        // 1000-1003
    case CLOSE_BAD_DATA .. CLOSE_UNEXPECTED:    // 1007-1011
      return 1;
    default:
      return 0;
  }
}
|
//! Describe a frame opcode for debug output.
//!
//! @returns
//!   The symbolic name (e.g. @expr{"FRAME_TEXT"@}) for known opcodes,
//!   otherwise the value in hex (e.g. @expr{"0x3"@}).
string describe_opcode(FRAME op) {
  switch (op) {
    case FRAME_CONTINUATION: return "FRAME_CONTINUATION";
    case FRAME_TEXT:         return "FRAME_TEXT";
    case FRAME_BINARY:       return "FRAME_BINARY";
    case FRAME_CLOSE:        return "FRAME_CLOSE";
    case FRAME_PING:         return "FRAME_PING";
    case FRAME_PONG:         return "FRAME_PONG";
  }
  return sprintf("0x%x", op);
}
|
//! Parse a Sec-WebSocket-Extensions header value.
//!
//! @returns
//!   A mapping from extension name to a mapping of its parameters.  Each
//!   parameter value is trimmed and stripped of surrounding double quotes.
//!   It is then converted to an int or float when that conversion
//!   round-trips exactly, and otherwise kept as a string.  A parameter
//!   without a value maps to @expr{""@}.  A zero @[header] gives an empty
//!   mapping.
mapping(string:mapping) parse_websocket_extensions(string header) {
  mapping(string:mapping) result = ([]);
  if (!header) return result;

  // Outer repetitions are the comma-separated extensions; each inner
  // repetition is a ({ name, raw_value }) pair, and the first pair names
  // the extension itself.
  array extension_list = array_sscanf(header,
    "%*[ \t\r\n]%{%{%[^ \t\r\n=;,]%*[= \t\r\n]%[^;,]%*[ \t\r\n;]%}"
    "%*[ \t\r\n,]%}")[0];

  foreach (extension_list;; array entry) {
    array pairs = entry[0];
    mapping params = ([]);
    result[pairs[0][0]] = params;

    foreach (pairs[1..];; array pair) {
      string value = String.trim_whites(pair[1]);
      if (sizeof(value) && value[0] == '"')
        value = value[1..<1];

      // Prefer int, then float, and fall back to the string itself.
      int|float|string converted = (int)value;
      if ((string)converted != value) {
        converted = (float)value;
        if ((string)converted != value)
          converted = value;
      }
      params[pair[0]] = converted;
    }
  }
  return result;
}
|
//! Render extensions as a Sec-WebSocket-Extensions header value; this is
//! the counterpart of @[parse_websocket_extensions].
//!
//! Extensions are joined with @expr{","@} and each extension's parameters
//! with @expr{";"@}.  A parameter whose value casts to @expr{""@} is
//! emitted bare.  String values containing a space are wrapped in double
//! quotes.
string encode_websocket_extensions(mapping(string:mapping) ext) {
  array(string) rendered = ({});
  foreach (ext; string ext_name; mapping params) {
    array(string) parts = ({ ext_name });
    foreach (params; string param_name; int|float|string param_value) {
      string text = (stringp(param_value) && has_value(param_value, " "))
        ? "\"" + param_value + "\""
        : (string)param_value;
      parts += ({ sizeof(text) ? param_name + "=" + text : param_name });
    }
    rendered += ({ parts * ";" });
  }
  return rendered * ",";
}
|
|
//! Decode one frame from the head of @[buf].
//!
//! If @[buf] does not yet hold a complete frame, the short read throws,
//! the buffer is rewound to where it started, and the error propagates
//! (see @[parse]).
//!
//! @returns
//!   The decoded @[Frame], with the payload already unmasked and
//!   @[Frame()->rsv] holding only the RSV1-RSV3 header bits.
protected Frame low_parse(Connection con, Stdio.Buffer buf) {
  Stdio.Buffer.RewindKey rewind_key = buf->rewind_on_error();

  // First header byte: FIN (bit 7), RSV1-3 (bits 6-4), opcode (bits 3-0).
  int head = buf->read_int8();
  // Second byte: MASK flag (bit 7) and 7-bit payload length.
  int len = buf->read_int8();
  int(0..1) masked = len >> 7;
  len &= 127;

  // 126 and 127 announce a 16-bit or 64-bit extended payload length.
  if (len == 126) {
    len = buf->read_int16();
  } else if (len == 127) {
    len = buf->read_int(8);
  }

  string mask;
  if (masked) mask = buf->read(4);

  string data = buf->read(len);
  // The whole frame was available: keep the bytes consumed.
  rewind_key->release();

  Frame f = Frame(head & 15);
  f->fin = head >> 7;
  f->mask = mask;
  // Keep only the reserved bits.  Storing the whole header byte leaked
  // FIN/opcode bits into rsv, which Frame->encode() ORs back into the
  // header when the frame is sent again.
  f->rsv = head & (RSV1|RSV2|RSV3);

  if (masked) data = MASK(data, mask);
  f->data = data;

  return f;
}
|
|
//! Decode the next frame from @[in].
//!
//! @returns
//!   The frame, or @expr{UNDEFINED@} if @[in] does not yet contain a
//!   complete frame (the buffer is left untouched in that case).
//!
//! @throws
//!   Any error other than a buffer underrun is rethrown.
Frame parse(Connection con, Stdio.Buffer in) {
  mixed failure = catch {
    return low_parse(con, in);
  };

  // A buffer error only means "not enough data yet".
  if (objectp(failure) && failure->buffer_error) return UNDEFINED;

  throw(failure);
}
|
//! A single WebSocket frame.
class Frame {
  //! The frame opcode.
  FRAME opcode;

  //! Final-fragment flag: 1 unless continuation frames of the same
  //! message follow.
  int(0..1) fin = 1;

  //! Reserved header bits; see @[RSV1], @[RSV2] and @[RSV3].  Only those
  //! bits of this field are written by @[encode()].
  int rsv;

  // NOTE(review): not referenced anywhere in this file.
  mapping(string:mixed) options;

  //! 4-byte masking key, or zero for an unmasked frame.
  string mask;

  //! The raw 8-bit payload.  Text frames hold UTF-8 here; use @[text] for
  //! the decoded string.
  string data = "";

  //! @param opcode
  //!   The frame type.
  //! @param data
  //!   Optional payload, depending on the frame type:
  //!   - text frames: a string, which is UTF-8 encoded;
  //!   - ping, pong, binary and continuation frames: an 8-bit string;
  //!   - close frames: a @[CLOSE_STATUS] or an 8-bit string.
  //! @param fin
  //!   Final-fragment flag for text, binary and continuation frames
  //!   (defaults to 1).  Ignored when @[data] is zero.
  //! @throws
  //!   An error when @[data] has the wrong type or width for @[opcode].
  protected void create(FRAME opcode, void|string|CLOSE_STATUS data, void|int(0..1) fin) {
    this::opcode = opcode;
    if (data) switch (opcode) {
    case FRAME_TEXT:
      this::data = string_to_utf8(data);
      this::fin = undefinedp(fin) || fin;
      break;
    case FRAME_PONG:
    case FRAME_PING:
      if (!stringp(data) || String.width(data) != 8)
        error("Bad argument. Expected string(8bit).\n");
      this::data = data;
      break;
    case FRAME_BINARY:
      if (!stringp(data) || String.width(data) != 8)
        error("Bad argument. Expected string(8bit).\n");
      this::fin = undefinedp(fin) || fin;
      this::data = data;
      break;
    case FRAME_CLOSE:
      if (intp(data)) {
        this::data = sprintf("%2c", data);
      } else if (stringp(data) && String.width(data) == 8) {
        this::data = data;
      } else error("Bad argument. Expected CLOSE_STATUS or string(8bit).\n");
      break;
    case FRAME_CONTINUATION:
      if (!stringp(data))
        error("Bad argument. Expected string.\n");
      if (String.width(data) != 8)
        error("%s frames cannot hold widestring data.\n",
              describe_opcode(opcode));
      this::data = data;
      this::fin = undefinedp(fin) || fin;
      break;
    }
  }

  protected string _sprintf(int type) {
    return type=='O' && sprintf("%O(%s, fin: %d, rsv: %d, %d bytes)",
                                this_program,
                                describe_opcode(opcode), fin, rsv & (RSV1|RSV2|RSV3),
                                sizeof(data));
  }

  // Cached decoding of data, and the exact payload string it was decoded
  // from.  Extensions reassign data (fragment reassembly, inflating), so
  // the cache is only valid while data is still that same string.  Pike
  // strings are shared, so the comparison is cheap.
  private string _text;
  private string _text_source;

  //! The payload of a text frame, decoded from UTF-8.
  //! @throws
  //!   An error if this is not a text frame or the payload is not valid
  //!   UTF-8.
  string `text() {
    if (opcode != FRAME_TEXT) error("Not a text frame.\n");
    if (!_text || _text_source != data) {
      _text = utf8_to_string(data);
      _text_source = data;
    }
    return _text;
  }

  //! Set the payload of a text frame; it is stored UTF-8 encoded in
  //! @[data].
  string `text=(string s) {
    if (opcode != FRAME_TEXT) error("Not a text frame.\n");
    data = string_to_utf8(s);
    _text = s;
    _text_source = data;
    return s;
  }

  //! The status code of a close frame.  An empty payload is reported as
  //! @[CLOSE_NORMAL]; a payload too short to hold a code as @[CLOSE_ERROR].
  CLOSE_STATUS `reason() {
    int i;
    if (opcode != FRAME_CLOSE)
      error("This is not a close frame.\n");
    if (!sizeof(data)) return CLOSE_NORMAL;
    if (sscanf(data, "%2c", i) == 1) return i;
    return CLOSE_ERROR;
  }

  //! Replace the payload of a close frame with just the status code @[r].
  CLOSE_STATUS `reason=(CLOSE_STATUS r) {
    if (opcode != FRAME_CLOSE)
      error("This is not a close frame.\n");
    data = sprintf("%2c", r);
    return r;
  }

  //! The UTF-8 decoded reason text that follows the status code in a
  //! close frame, or zero if there is none.  Throws on invalid UTF-8.
  string `close_reason() {
    if (opcode != FRAME_CLOSE)
      error("This is not a close frame.\n");
    if (sizeof(data) <= 2) return 0;
    return utf8_to_string(data[2..]);
  }

  //! Append the wire encoding of this frame to @[buf].  The payload is
  //! masked with @[mask] when one is set.
  void encode(Stdio.Buffer buf) {
    // Only the reserved bits come from rsv, so stray FIN/opcode bits left
    // in it cannot override fin and opcode.
    buf->add_int8(fin << 7 | (rsv & (RSV1|RSV2|RSV3)) | opcode);

    // Payload length: 7-bit, or 126/127 followed by a 16/64-bit length.
    if (sizeof(data) > 0xffff) {
      buf->add_int8(!!mask << 7 | 127);
      buf->add_int(sizeof(data), 8);
    } else if (sizeof(data) > 125) {
      buf->add_int8(!!mask << 7 | 126);
      buf->add_int16(sizeof(data));
    } else buf->add_int8(!!mask << 7 | sizeof(data));

    if (mask) {
      buf->add(mask, MASK(data, mask));
    } else {
      buf->add(data);
    }
  }

  //! Casting to string yields the encoded frame.
  protected string cast(string to)
  {
    if (to == "string") {
      Stdio.Buffer buf = Stdio.Buffer();
      encode(buf);
      return buf->read();
    }
    return UNDEFINED;
  }
}
|
|
//! A WebSocket connection.
//!
//! Server-side connections wrap an already upgraded socket (see
//! @[Request()->websocket_accept()]).  Client-side connections are created
//! without arguments and opened with @[connect()].
class Connection {
  //! The underlying transport.
  Stdio.File|SSL.File stream;

  //! Outgoing data waiting to be written.
  Stdio.Buffer out = Stdio.Buffer();

  //! Received data not yet parsed into frames.  Error mode is enabled so
  //! that short reads throw, which @[parse()] relies on to detect partial
  //! frames.
  Stdio.Buffer in = Stdio.Buffer()->set_error_mode(1);

  // Non-zero when stream has been put in buffer mode on in/out, in which
  // case the stream drains out by itself.
  protected int buffer_mode = 0;

  //! The URI connected to (client side only).
  protected Standards.URI endpoint;

  //! Extra headers sent with the client handshake request.
  protected mapping(string:string) extra_headers;

  // Passed to the callbacks instead of this object when non-zero.
  protected mixed id;

  // Active extension objects, consulted in order by send() and
  // websocket_in().
  protected array(object) extensions;

  //! Mask outgoing frames with a random key (enabled for client
  //! connections).
  int(0..1) masking;

  //! Connection states.
  enum STATE {

    //! Opening handshake in progress.
    CONNECTING = 0x0,

    //! Handshake completed; frames may be sent.
    OPEN,

    //! Our close frame has been sent; waiting for the peer's.
    CLOSING,

    //! The connection is closed.
    CLOSED,
  };

  //! Current connection state.
  STATE state = CONNECTING;

  // Status code of the close frame sent by send(); reported to onclose
  // when the peer's answering close frame arrives.
  protected CLOSE_STATUS close_reason;

  protected string _sprintf(int type) {
    return sprintf("%O(%d, %O, %s, %s)", this_program, state, stream,
                   endpoint?(string)endpoint:"server",
                   buffer_mode?"buffer mode": "callback mode only");
  }

  //! Set the value passed as the last argument to @[onopen], @[onmessage]
  //! and @[onclose] instead of this object.
  void set_id(mixed id) {
    this::id = id;
  }

  //! Take over the already upgraded server-side stream @[f].
  //! @param extensions
  //!   Extension objects to use, if an array is given.
  protected void create(Stdio.File|SSL.File f, void|int|array(object) extensions) {
    if (arrayp(extensions)) this_program::extensions = extensions;
    stream = f;
    if (f->set_buffer_mode) {
      f->set_buffer_mode(in, out);
      buffer_mode = 1;
    }
    f->set_nonblocking(websocket_in, websocket_write, websocket_closed);

    state = OPEN;
    if (onopen) onopen(id || this);
    WS_WERR(2, "opened\n");
  }

  //! Create an unconnected client connection; see @[connect()].
  protected variant void create() {
    masking = 1;
    state = CLOSED;
  }

  //! Build the client handshake request headers for @[endpoint] and
  //! record the Sec-WebSocket-Accept value expected in the response.
  //!
  //! @param extensions
  //!   Extension factories.  Each is called as @expr{f(1, 0, rext)@} to
  //!   add its offer to @expr{rext@}, and is replaced in the array by its
  //!   result if that is an object.  The array is modified in place.
  //!
  //! @returns
  //!   @expr{({ headers, rext })@}; @expr{rext@} is zero when no extension
  //!   array was given.
  protected array(mapping) low_connect(Standards.URI endpoint,
                                       mapping(string:string) extra_headers,
                                       void|array extensions)
  {
    string host = endpoint->host;

    if (endpoint->port) host += ":" + endpoint->port;

    mapping headers = ([
      "Host" : host,
      "Connection" : "Upgrade",
      "User-Agent" : agent,
      "Accept": "*/*",
      "Upgrade" : "websocket",
      "Sec-WebSocket-Key" :
        MIME.encode_base64(Crypto.Random.random_string(16), 1),
      "Sec-WebSocket-Version": (string)websocket_version,
    ]);

    // Extra headers override the generated ones.  Tolerate zero from
    // subclasses calling this directly.
    foreach(extra_headers || ([]); string idx; string val) {
      headers[idx] = val;
    }

    expected_accept =
      MIME.encode_base64(Crypto.SHA1.hash(headers["Sec-WebSocket-Key"] +
                                          websocket_id));

    mapping rext;

    if (arrayp(extensions)) {
      rext = ([]);

      foreach (extensions; int i; extension_factory f) {
        mixed o = f(1, 0, rext);
        if (objectp(o)) extensions[i] = o;
      }

      if (sizeof(rext))
        headers["Sec-WebSocket-Extensions"] = encode_websocket_extensions(rext);
    }

    return ({ headers, rext });
  }

  //! Open a client connection to @[endpoint] (a @tt{ws://@} or
  //! @tt{wss://@} URL) and send the opening handshake request.
  //!
  //! The TCP connect and, for wss, the TLS handshake are done before this
  //! returns.  The HTTP upgrade completes asynchronously, after which
  //! @[onopen] is called.
  //!
  //! @param extra_headers
  //!   Additional request headers; these override the generated ones.
  //! @param extensions
  //!   Extension factories to offer; the array is copied before use.
  //!
  //! @returns
  //!   Zero if the connection could not be established, otherwise the
  //!   non-zero result of the TCP connect.
  //!
  //! @throws
  //!   An error if @[endpoint] has a scheme other than ws or wss.
  int connect(string|Standards.URI endpoint, void|mapping(string:string) extra_headers,
              void|array extensions) {
    if (stringp(endpoint)) endpoint = Standards.URI(endpoint);
    this_program::endpoint = endpoint;
    this_program::extra_headers = extra_headers = extra_headers || ([]);

    if (endpoint->path == "") endpoint->path = "/";

    Stdio.File f = Stdio.File();
    state = CONNECTING;

    int port;

    if (endpoint->scheme == "ws") {
      port = endpoint->port || 80;
    } else if (endpoint->scheme == "wss") {
      port = endpoint->port || 443;
    } else error("Not a WebSocket URL.\n");

    int res = f->connect(endpoint->host, port);

    if (!res) {
      websocket_closed();
      return 0;
    }

    if (endpoint->scheme == "wss") {
      SSL.Context ctx = SSL.Context();
      stream = SSL.File(f, ctx);
      object ssl_session = stream->connect(endpoint->host,0);
      if (!ssl_session) {
        WS_WERR(1, "Handshake failed\n");
        websocket_closed();
        return 0;
      }
    } else {
      stream = f;
    }

    buffer_mode = 0;

    if (arrayp(extensions)) {
      // low_connect() and http_read() modify the array; keep the caller's
      // copy intact.
      extensions = extensions + ({ });
    }

    [mapping headers, mapping rext] =
      low_connect(endpoint, extra_headers, extensions);

    stream->set_nonblocking(curry_back(http_read, _Roxen.HeaderParser(), extensions, rext),
                            websocket_write, websocket_closed);

    send_raw("GET ", endpoint->get_http_path_query(), " HTTP/1.1\r\n");
    foreach(headers; string h; string v) {
      send_raw(h, ": ", v, "\r\n");
    }
    send_raw("\r\n");
    return res;
  }

  //! Called when the connection has opened, with @[id] (or this object)
  //! and, on the client side, the response headers.
  function(mixed,void|mixed:void) onopen;

  //! Called with every received frame that is not swallowed by an
  //! extension or handled internally (ping and close frames).
  function(Frame, mixed:void) onmessage;

  //! Called once when the connection closes, with the close status (zero
  //! when the stream closed without a closing handshake).
  function(CLOSE_STATUS, mixed:void) onclose;

  //! Number of bytes queued in @[out] and not yet written.
  int `bufferdAmount() {
    return sizeof(out);
  }

  //! Correctly spelled alias for @[bufferdAmount].
  int `bufferedAmount() {
    return sizeof(out);
  }

  //! Queue raw bytes for writing, bypassing framing and extensions.
  void send_raw(string(8bit) ... s) {
    WS_WERR(3, "out:\n----\n%s\n----\n", s*"\n----\n");
    out->add(@s);
    stream->write("");
  }

  // Sec-WebSocket-Accept value the server must answer with.
  protected string expected_accept;

  // Read callback during the client handshake.  It feeds the response to
  // the header parser hp.  Once the headers are complete it validates
  // them (RFC 6455 section 4.1), finishes extension negotiation, switches
  // to frame processing and calls onopen.  Any validation failure closes
  // the connection.
  protected void http_read(mixed _id, string data,
                           object hp, array(extension_factory) extensions, mapping rext) {

    if (state != CONNECTING) {
      websocket_closed();
      return;
    }

    // ({ data after the headers, status line, header mapping }), or zero
    // while the headers are still incomplete.
    array tmp = hp->feed(data);
    if (!tmp) return;

    int major, minor, status;
    mapping headers = tmp[2];

    WS_WERR(2, "http_read: header done. Parsed: %O\n", tmp);

    // The reason phrase is optional in practice ("HTTP/1.1 101"), so only
    // the version and the status code are required.
    if (sscanf(tmp[1], "HTTP/%d.%d %d%*s", major, minor, status) != 3) {
      websocket_closed();
      return;
    }

    if (status != 101) {
      WS_WERR(1, "http_read: Bad http status code: %d.\n", status);
      websocket_closed();
      return;
    }

    if (lower_case(headers["upgrade"] || "") != "websocket") {
      WS_WERR(1, "http_read: No upgrade header.\n");
      websocket_closed();
      return;
    }

    // Connection is a comma-separated token list (e.g. "Upgrade,
    // Keep-Alive"); it only has to contain the token "upgrade".
    array(string) connection_tokens =
      map(lower_case(headers["connection"] || "") / ",", String.trim_whites);
    if (!has_value(connection_tokens, "upgrade")) {
      WS_WERR(1, "http_read: No connection header with upgrade.\n");
      websocket_closed();
      return;
    }

    if (headers["sec-websocket-accept"] != expected_accept) {
      WS_WERR(1, "http_read: Missing or invalid Sec-WebSocket-Accept.\n");
      websocket_closed();
      return;
    }

    // A missing version header is accepted.
    if (!has_value((array(int))((headers["sec-websocket-version"] ||
                                 (string)websocket_version)/","),
                   websocket_version)) {
      WS_WERR(1, "http_read: Unsupported Sec-WebSocket-Version: %O.\n",
              headers["sec-websocket-version"]);
      websocket_closed();
      return;
    }

    if (arrayp(extensions)) {
      mapping ext = parse_websocket_extensions(headers["sec-websocket-extensions"]);

      // Factories that did not already produce an object in low_connect()
      // are called again with the server's answer.
      foreach (extensions; int i; object|extension_factory f) {
        if (!objectp(f))
          extensions[i] = f(1, ext, rext);
      }

      extensions = filter(extensions, objectp);

      if (sizeof(extensions)) this_program::extensions = extensions;
    }

    // SSL.File stays in callback mode; plain sockets use buffer mode.
    if (endpoint->scheme != "wss") {
      stream->set_buffer_mode(in, out);
      buffer_mode = 1;
    }

    stream->set_nonblocking(websocket_in, websocket_write, websocket_closed);

    state = OPEN;
    if (onopen) onopen(id || this, headers);
    WS_WERR(2, "opened\n");

    // Frames may have arrived together with the response headers.
    if (sizeof(tmp[0])) {
      in->add(tmp[0]);
      websocket_in(_id, in);
    }
  }

  // Write callback in callback mode: flush as much of out as possible and
  // treat a write error as a closed connection.  In buffer mode the
  // stream drains out by itself.
  protected void websocket_write() {
    if (buffer_mode) return;
    if (sizeof(out)) {
      int n = out->output_to(stream);
      // Ask the stream for its error.  The global errno() reflects the
      // last system call of any kind, which need not be this write
      // (notably for SSL.File).
      if (n < 0 && stream->errno()) {
        websocket_closed();
      }
    }
  }

  // Read callback in callback mode: buffer the data and parse it.
  protected void websocket_in(mixed _id, string data) {
    in->add(data);
    websocket_in(_id, in);
  }

  // Parse and dispatch every complete frame in in.
  protected variant void websocket_in(mixed _id, Stdio.Buffer in) {
    FRAMES: while (Frame frame = parse(this, in)) {
      if (state == CLOSED) return;

      // Extensions may transform a frame, or swallow it by returning zero.
      if (extensions) foreach (extensions;; object e) {
        if (e->receive) {
          frame = e->receive(frame, this);
          if (!frame) continue FRAMES;
        }
      }

      int opcode = frame->opcode;
      WS_WERR(2, "%O in %O\n", this, frame);

      switch (opcode) {
      case FRAME_PING:
        // Answer with a pong carrying the same payload.  Once our close
        // frame is out, the write side is shut down and send() would throw.
        if (state == OPEN) send(Frame(FRAME_PONG, frame->data));
        continue;
      case FRAME_CLOSE:
        if (!is_valid_close(frame->reason)) {
          WS_WERR(1, "Received invalid close reason: %d\n", frame->reason);
          fail();
          return;
        }
        if (state == OPEN) {
          // Peer-initiated close; the optional reason text must be UTF-8.
          if (catch(frame->close_reason)) {
            WS_WERR(1, "Non utf8 text in close frame.\n");
            fail(CLOSE_BAD_DATA);
            return;
          }
          // Echo the close frame, then report the peer's status.
          close(frame->reason);
          close_event(frame->reason);
        } else if (state == CLOSING) {
          // The peer answered our close frame: the handshake is complete.
          if (stream) destruct(stream);
          close_event(close_reason);
        }
        return;
      }

      if (onmessage) onmessage(frame, id || this);
    }
  }

  // Enter CLOSED and call onclose at most once.
  protected void close_event(CLOSE_STATUS reason) {
    state = CLOSED;
    if (onclose) {
      onclose(reason, id || this);
      onclose = 0;
    }
  }

  // The stream closed or failed without a closing handshake.
  protected void websocket_closed() {
    if (stream) destruct(stream);

    close_event(0);
    WS_WERR(2, "closed\n");
  }

  //! Send a ping frame with the optional payload @[s].
  void ping(void|string s) {
    send(Frame(FRAME_PING, s));
  }

  //! Start the closing handshake.  This sends a close frame with status
  //! @[reason] (default @[CLOSE_NORMAL]) and optional reason text @[msg],
  //! which is UTF-8 encoded.
  void close(void|string(8bit)|CLOSE_STATUS reason, void|string msg) {
    if (!reason) reason = CLOSE_NORMAL;
    if (msg) {
      send(Frame(FRAME_CLOSE, sprintf("%2c%s", reason, string_to_utf8(msg))));
    } else send(Frame(FRAME_CLOSE, reason));
  }

  //! Fail the connection.  A close frame with @[reason] (default
  //! @[CLOSE_ERROR]) is sent if the connection is still open; then
  //! @[onclose] is called and the stream is dropped.
  void fail(void|CLOSE_STATUS reason) {
    if (!reason) reason = CLOSE_ERROR;
    // Only an open connection can still send a close frame.  In CLOSING,
    // send() throws, e.g. when an extension rejects a frame that arrives
    // during the closing handshake.
    if (state == OPEN) close(reason);
    close_event(reason);
    if (stream) destruct(stream);
  }

  //! Send @[frame] after passing it through the extensions, which may
  //! transform or drop it.  Client connections mask it with a fresh random
  //! key.  Sending a close frame moves the connection to @[CLOSING] and
  //! shuts down the write side of the stream.
  //! @throws
  //!   An error if the connection is not @[OPEN].
  void send(Frame frame) {
    int opcode = frame->opcode;
    if (state != OPEN)
      error("WebSocket connection is not open: %O.\n", this);
    if (extensions) foreach (extensions;; object e) {
      if (e->send) {
        frame = e->send(frame, this);
        if (!frame) return;
      }
    }

    if (masking) frame->mask = random_string(4);
    WS_WERR(2, "sending %O\n", frame);
    frame->encode(out);
    stream->write("");
    if (opcode == FRAME_CLOSE) {
      state = CLOSING;
      close_reason = frame->reason;
      stream->close("w");
    }
  }

  //! Send @[s] as a single text frame.
  void send_text(string s) {
    send(Frame(FRAME_TEXT, s));
  }

  //! Send a continuation frame; @[fin] defaults to 1.
  void send_continuation(string(8bit) data, void|int(0..1) fin) {
    send(Frame(FRAME_CONTINUATION, data, fin));
  }

  //! Send @[s] as a single binary frame.
  void send_binary(string(0..255) s) {
    send(Frame(FRAME_BINARY, s));
  }

}
|
|
//! HTTP server request that recognises WebSocket upgrade requests (RFC
//! 6455 section 4.2.1).
//!
//! For a valid upgrade request, @[cb] is called with the offered
//! sub-protocols and this request.  Other requests fall back to the
//! inherited handling.
class Request(function(array(string), Request:void) cb) {
  inherit Protocols.HTTP.Server.Request;

  protected int parse_variables() {
    WS_WERR(2, "parse_variables: headers: %O\n", request_headers);
    WS_WERR(2, "parse_variables: query: %O\n", query);
    WS_WERR(2, "parse_variables: variables: %O\n", variables);

    // Must be a GET over HTTP/1.1 or later.
    if ((request_type != "GET") || !has_prefix(protocol, "HTTP/") ||
        (protocol[sizeof("HTTP/")..] < "1.1")) {
      WS_WERR(1, "parse_variables: Not a websocket request (2).\n");
      return ::parse_variables();
    }

    // Upgrade must mention websocket.
    if (!has_value(lower_case(request_headers["upgrade"] || ""),
                   "websocket")) {
      WS_WERR(1, "parse_variables: Not a websocket request (5).\n");
      return ::parse_variables();
    }

    // Connection must mention upgrade.
    if (!has_value(lower_case(request_headers["connection"] || ""),
                   "upgrade")) {
      WS_WERR(1, "parse_variables: Not a websocket request (6).\n");
      return ::parse_variables();
    }

    // Sec-WebSocket-Key must be the base64 encoding of 16 bytes.  A missing
    // or undecodable key makes the decoder throw; that is caught here.
    string raw_key;
    catch {
      raw_key = MIME.decode_base64(request_headers["sec-websocket-key"]);
    };
    if (!raw_key || (sizeof(raw_key) != 16)) {
      WS_WERR(1, "parse_variables: Not a websocket request (7).\n");
      return ::parse_variables();
    }

    // Only protocol version 13 is supported.
    if (request_headers["sec-websocket-version"] !=
        (string)websocket_version) {
      WS_WERR(1, "parse_variables: Not a websocket request (9).\n");
      return ::parse_variables();
    }

    if (query!="")
      .HTTP.Server.http_decode_urlencoded_query(query,variables);
    flatten_headers();
    string proto = request_headers["sec-websocket-protocol"];
    // The header is a comma-separated token list; accept any whitespace
    // around the commas ("a,b" as well as "a, b") and drop empty entries.
    array(string) protocols =
      proto ? filter(map(proto / ",", String.trim_whites), sizeof) : ({});
    WS_WERR(1, "websocket request: %O\n", protocols);
    if (cb) {
      cb(protocols, this);
    }
    return 0;
  }

  //! Compute the 101 response headers and negotiate extensions.
  //!
  //! @param protocol
  //!   Sub-protocol to announce, or zero for none.
  //! @param extensions
  //!   Extension factories; each is called as @expr{f(0, ext, rext)@}
  //!   with the client's parsed offer.
  //! @param extra_headers
  //!   Additional response headers.
  //!
  //! @returns
  //!   @expr{({ headers, extension_objects })@}, where the second element
  //!   is zero if no extension produced an object.
  array(mapping(string:string)|array)
    low_websocket_accept(string|void protocol,
                         array(extension_factory)|void extensions,
                         mapping(string:string)|void extra_headers)
  {
    string s = request_headers["sec-websocket-key"] + websocket_id;
    mapping heads = ([
      "Upgrade" : "websocket",
      "Connection" : "Upgrade",
      "Sec-WebSocket-Accept" : MIME.encode_base64(Crypto.SHA1.hash(s)),
      "Sec-WebSocket-Version" : (string)websocket_version,
      "Server" : agent,
    ]);

    if (extra_headers) heads += extra_headers;

    array _extensions;
    mapping rext = ([]);

    if (extensions && sizeof(extensions)) {
      mapping ext = parse_websocket_extensions(request_headers["sec-websocket-extensions"]);
      array tmp = ({ });

      foreach (extensions;; extension_factory f) {
        object e = f(0, ext, rext);
        if (e) tmp += ({ e });
      }

      if (sizeof(tmp)) _extensions = tmp;
      if (sizeof(rext))
        heads["Sec-WebSocket-Extensions"] = encode_websocket_extensions(rext);
    }

    // Same capitalisation as the other Sec-WebSocket-* headers.
    if (protocol) heads["Sec-WebSocket-Protocol"] = protocol;

    return ({ heads, _extensions });
  }

  //! Accept the WebSocket upgrade.  This sends the 101 response on this
  //! request's socket and returns a @[Connection] that takes the socket
  //! over.
  //!
  //! @param protocol
  //!   Sub-protocol to announce, or zero for none.
  //! @param extensions
  //!   Extension factories to negotiate against the client's offer.
  //! @param extra_headers
  //!   Additional response headers.
  Connection websocket_accept(string protocol, void|array(extension_factory) extensions,
                              void|mapping extra_headers) {
    [mapping heads, array _extensions] =
      low_websocket_accept(protocol, extensions, extra_headers);

    Connection ws = Connection(my_fd, _extensions);
    WS_WERR(2, "Using extensions: %O\n", _extensions);
    // The socket now belongs to the WebSocket connection.
    my_fd = 0;

    // Standard reason phrase for status 101 (RFC 7231 section 6.2.2).
    ws->send_raw("HTTP/1.1 101 Switching Protocols\r\n");

    foreach (heads; string k; string v) {
      ws->send_raw(sprintf("%s: %s\r\n", k, v));
    }

    ws->send_raw("\r\n");

    finish(0);

    return ws;
  }
}
|
|
|
|
|
|
|
|
|
|
//! HTTP server port that also recognises WebSocket upgrade requests.
class Port {
  inherit Protocols.HTTP.Server.Port;

  //! @param http_cb
  //!   Callback for ordinary HTTP requests.
  //! @param ws_cb
  //!   If given, requests are handled by @[Request] bound to this callback.
  //!   The callback is called with the offered sub-protocols and the
  //!   request for valid WebSocket upgrade requests.
  //! @param portno
  //! @param interface
  //!   Passed on to @[Protocols.HTTP.Server.Port].
  protected void create(function(Protocols.HTTP.Server.Request:void) http_cb,
                        function(array(string), Request:void)|void ws_cb,
                        void|int portno, void|string interface) {
    ::create(http_cb, portno, interface);

    // Without a WebSocket callback this is a plain HTTP port.
    if (!ws_cb) return;
    request_program = Function.curry(Request)(ws_cb);
  }
}
|
|
|
|
//! HTTPS server port that also recognises WebSocket upgrade requests.
class SSLPort {
  inherit Protocols.HTTP.Server.SSLPort;

  //! @param http_cb
  //!   Callback for ordinary HTTP requests.
  //! @param ws_cb
  //!   If given, requests are handled by @[Request] bound to this callback.
  //!   The callback is called with the offered sub-protocols and the
  //!   request for valid WebSocket upgrade requests.
  //! @param portno
  //! @param interface
  //! @param key
  //! @param certificate
  //!   Passed on to @[Protocols.HTTP.Server.SSLPort].
  protected void create(function(Protocols.HTTP.Server.Request:void) http_cb,
                        function(array(string), Request:void)|void ws_cb,
                        void|int portno, void|string interface,
                        void|string key, void|string|array certificate) {
    ::create(http_cb, portno, interface, key, certificate);

    // Without a WebSocket callback this is a plain HTTPS port.
    if (!ws_cb) return;
    request_program = Function.curry(Request)(ws_cb);
  }
}
|
|
|
//! An extension factory: a program, or a function called as
//! @expr{factory(client_mode, ext, rext)@}, where:
//! - @expr{ext@} is the parsed extension header from the peer (zero on
//!   the client's initial call, before the handshake);
//! - @expr{rext@} is a mapping the factory adds its own extension
//!   parameters to, to be sent to the peer.
//! The factory returns an extension object, or zero for none.
typedef function(int(0..1),mapping,mapping:object)|program extension_factory;


//! Prototype for extension objects.  Each hook returns the (possibly
//! modified) frame to continue with, or zero to swallow the frame.
class Extension {

  //! Called for every received frame, in extension order.
  Frame receive(Frame, Connection con);

  //! Called for every frame about to be sent, in extension order.
  Frame send(Frame, Connection con);
}
|
|
//! Extension that reassembles fragmented data messages into one frame.
//! It fails the connection on fragmentation protocol violations.
class defragment {
  inherit Extension;

  // First frame of the data message currently being reassembled.
  private Frame fragment;

  Frame receive(Frame frame, Connection con) {
    int opcode = frame->opcode;
    int(0..1) fin = frame->fin;

    if (opcode == FRAME_CONTINUATION) {
      // A continuation must belong to a message in progress.
      if (!fragment) {
        con->fail();
        WS_WERR(1, "Bad continuation.\n");
        return 0;
      }
      fragment->data += frame->data;

      // More fragments follow: keep collecting.
      if (!fin) return 0;

      // Final fragment: hand on the assembled message.
      Frame assembled = fragment;
      assembled->fin = 1;
      fragment = 0;
      return assembled;
    }

    if (!fin) {
      // Start of a fragmented message; only data frames may be fragmented.
      if (opcode != FRAME_TEXT && opcode != FRAME_BINARY) {
        WS_WERR(1, "Received fragmented control frame. closing connection.\n");
        con->fail();
        return 0;
      }
      if (fragment) {
        con->fail();
        WS_WERR(1, "Unfinished fragmented message.\n");
        return 0;
      }
      fragment = frame;
      return 0;
    }

    // A complete frame.  Only control frames may interleave with a message
    // that is being reassembled.
    if (fragment && !(opcode & 0x8)) {
      con->fail();
      WS_WERR(1, "Non control frame during fragmented traffic.\n");
      return 0;
    }

    return frame;
  }
}
|
#if constant(Gz.deflate) |
//! Extension object implementing permessage-deflate; created by the
//! factory returned from permessagedeflate().  Received messages are
//! reassembled (via defragment) and then inflated when RSV1 is set.
//!
//! NOTE(review): outgoing compression is currently disabled: try_compress()
//! returns immediately, so sent frames are never compressed (RSV1 is
//! never set on send).
class _permessagedeflate {
  inherit defragment;

  // Inflater for received messages; created on first use and kept across
  // messages.
  protected Gz.inflate uncompress;
  // Deflater for sent messages (only used by the disabled try_compress()).
  protected Gz.deflate compress;

  // Negotiated options; see deflate_default_options for the keys.
  mapping options;

  void create(mapping options) {
    this_program::options = options;
  }

  // Compress the payload of a data frame in place and set RSV1.
  // NOTE(review): disabled by the unconditional return below; everything
  // after it is unreachable.
  private void try_compress(Frame frame) {
    return;
    mapping(string:mixed) opts = options;
    // Skip payloads below the configured threshold.
    if (sizeof(frame->data) >=
        (opts->compressionNoContextTakeover
         ? opts->compressionThresholdNoContext
         : opts->compressionThreshold)) {
      if (!compress)
        compress = Gz.deflate(-opts->compressionLevel,
                              opts->compressionStrategy,
                              opts->compressionWindowSize);
      int wsize = opts->compressionWindowSize
        ? 1<<opts->compressionWindowSize : 1<<15;
      if (opts->compressionNoContextTakeover) {
        // Fresh deflater per message; [..<4] strips the trailing
        // 0x00 0x00 0xff 0xff of the sync flush (receive() re-appends it).
        string s
          = compress->deflate(frame->data, Gz.SYNC_FLUSH)[..<4];
        if (sizeof(s) < sizeof(frame->data)) {
          frame->data = s;
          frame->rsv |= RSV1;
        }
        compress = 0;
      } else {
        if (opts->compressionHeuristics == OVERRIDE_COMPRESS
            || frame->opcode == FRAME_TEXT) {
          // Always compress text (or everything when overridden).
          frame->data
            = compress->deflate(frame->data, Gz.SYNC_FLUSH)[..<4];
          frame->rsv |= RSV1;
        } else if (4*sizeof(frame->data) <= wsize) {
          // Small payload: try it, and roll back the deflater state if
          // the result is not smaller.
          // NOTE(review): declared Gz.inflate although clone() is called
          // on a Gz.deflate.
          Gz.inflate save = compress->clone();
          string s
            = compress->deflate(frame->data, Gz.SYNC_FLUSH);
          if (sizeof(s) < sizeof(frame->data)) {
            frame->data = s[..<4];
            frame->rsv |= RSV1;
          } else
            compress = save;
        } else {
          // Large payload: estimate compressibility from the first 1024
          // bytes using a throw-away clone of the deflater.
          Gz.inflate ctest = compress->clone();
          string sold = frame->data[..1023];
          string s = ctest->deflate(sold, Gz.PARTIAL_FLUSH);
          if (sizeof(s) + 64 < sizeof(sold)) {
            frame->data = compress->deflate(frame->data,
                                            Gz.SYNC_FLUSH)[..<4];
            frame->rsv |= RSV1;
          }
        }
      }
    }
  }

  // Outgoing hook: only text and binary frames are candidates.
  Frame send(Frame frame, Connection con) {
    int opcode = frame->opcode;

    if (opcode == FRAME_TEXT || opcode == FRAME_BINARY) try_compress(frame);

    return frame;
  }

  // Incoming hook: reassemble fragments first (defragment), then inflate
  // messages that have RSV1 set.  The connection is failed with
  // CLOSE_EXTENSION on misuse of RSV1 or on an inflate error.
  Frame receive(Frame frame, Connection con) {
    frame = ::receive(frame, con);

    if (!frame) return 0;

    int opcode = frame->opcode;
    int rsv1 = frame->rsv & RSV1;

    if (rsv1) {
      // Only data messages may be compressed.
      if (opcode != FRAME_BINARY && opcode != FRAME_TEXT) {
        con->fail(CLOSE_EXTENSION);
        WS_WERR(1, "Received compressed non-data frame.\n");
        return 0;
      }

      // A compressionLevel of zero means compressed frames are refused.
      if (!options->compressionLevel) {
        con->fail(CLOSE_EXTENSION);
        WS_WERR(1, "Unexpected compressed frame.\n");
        return 0;
      }

      // RSV1 has been handled; clear it for later extensions.
      frame->rsv &= ~RSV1;

      // Negative window size: raw deflate stream.  The 4 appended bytes
      // restore the sync-flush tail that the sender stripped.
      if (!uncompress) uncompress = Gz.inflate(-options->decompressionWindowSize);
      if (mixed err = catch(frame->data = uncompress->inflate(frame->data + "\0\0\377\377"))) {
        con->fail(CLOSE_EXTENSION);
        master()->handle_error(err);
        return 0;
      }
    }

    return frame;
  }
}
|
|
|
|
|
//! Default options for permessagedeflate(); caller-supplied options are
//! merged on top of these.
constant deflate_default_options = ([
  // Passed negated as the level to Gz.deflate.  A zero value also makes
  // received compressed frames fail the connection.
  "compressionLevel":3,
  // Minimum payload size before compression is tried (context takeover).
  "compressionThreshold":5,
  // Minimum payload size when compressionNoContextTakeover is in effect.
  "compressionThresholdNoContext":256,
  // Strategy passed to Gz.deflate.
  "compressionStrategy":Gz.DEFAULT_STRATEGY,
  // Window bits for our deflater; may be lowered by the client's
  // server_max_window_bits offer.
  "compressionWindowSize":15,
  // Window bits for our inflater; may be lowered by the client's
  // client_max_window_bits offer.
  "decompressionWindowSize":15,
  // HEURISTICS_COMPRESS or OVERRIDE_COMPRESS.
  "compressionHeuristics":HEURISTICS_COMPRESS,
]);
#endif |
|
|
|
|
|
|
|
|
|
//! Create an extension factory for permessage-deflate (RFC 7692).
//!
//! @param default_options
//!   Overrides for @[deflate_default_options].
//!
//! @returns
//!   A factory for the extension lists of @[Connection()->connect()] and
//!   @[Request()->websocket_accept()].  When the peer does not take part
//!   in permessage-deflate, or offers invalid parameters, the factory
//!   returns a plain @[defragment] extension instead.  Without Gz support
//!   it always does.
object permessagedeflate(void|mapping default_options) {
#if constant(Gz.deflate)
  default_options = deflate_default_options + (default_options||([]));

  // RFC 7692 section 7.1.2: window-bits values must be integers 8..15.
  int(0..1) valid_window_bits(mixed bits) {
    return intp(bits) && bits >= 8 && bits <= 15;
  };

  object factory(int(0..1) client_mode, mapping ext, mapping rext) {

    if (client_mode && !ext) {
      // Initial client call, before the handshake: just offer the
      // extension without parameters.
      rext["permessage-deflate"] = ([]);
      return 0;
    }

    mapping parm = ext["permessage-deflate"];

    // The peer did not negotiate permessage-deflate; still reassemble
    // fragmented messages.
    if (!parm) return defragment();

    mapping options = default_options + ([]);

    if (!client_mode) {
      // Server side: answer the client's offer in rparm.
      mapping rparm = ([]);

      if (parm->client_no_context_takeover
          || options->decompressionNoContextTakeover) {
        options->decompressionNoContextTakeover = 1;
        rparm->client_no_context_takeover = "";
      }

      if (has_index(parm, "client_max_window_bits")) {
        mixed p = parm->client_max_window_bits;
        if (stringp(p)) {
          // Offered without a value: we may pick the client's window size.
          int w = options->decompressionWindowSize;
          if (w < 15 && w > 8)
            rparm->client_max_window_bits = w;
        } else if (valid_window_bits(p)) {
          p = min(p, options->decompressionWindowSize);
          options->decompressionWindowSize
            = max(rparm->client_max_window_bits = p, 8);
        } else {
          // Invalid value: the offer must be declined.
          return defragment();
        }
      }

      if (parm->server_no_context_takeover
          || options->compressionNoContextTakeover) {
        options->compressionNoContextTakeover = 1;
        rparm->server_no_context_takeover = "";
      }

      if (has_index(parm, "server_max_window_bits")) {
        mixed p = parm->server_max_window_bits;
        if (stringp(p)) {
          int w = options->compressionWindowSize;
          if (w < 15)
            rparm->server_max_window_bits = w;
        } else if (valid_window_bits(p)) {
          p = min(p, options->compressionWindowSize);
          if (p >= 8)
            options->compressionWindowSize
              = rparm->server_max_window_bits = p;
        } else {
          // Invalid value: the offer must be declined.
          return defragment();
        }
      }

      rext["permessage-deflate"] = rparm;
    }

    return _permessagedeflate(options);
  };
#else
  // No Gz support: no compression is possible, but fragmented messages
  // are still reassembled.  The signature must match extension_factory,
  // because the factory is always called with three arguments.
  object factory(int(0..1) client_mode, mapping ext, mapping rext) {
    return defragment();
  };
#endif

  return factory;
}
|
|
|
//! Extension that enforces RFC 6455 framing rules on received frames and
//! fails the connection on violations.
//!
//! NOTE(review): it rejects any frame with RSV bits set, so it presumably
//! has to come after extensions that consume those bits (such as
//! permessage-deflate) in the extension list.
class conformance_check {
  inherit Extension;

  Frame receive(Frame frame, Connection con) {
    int opcode = frame->opcode;

    // Validate UTF-8 only on complete messages.  A single fragment may
    // legitimately end in the middle of a multi-byte sequence.
    if (opcode == FRAME_TEXT && frame->fin && catch(frame->text)) {
      con->fail(CLOSE_BAD_DATA);
      WS_WERR(1, "Invalid utf8 in text frame.\n");
      return 0;
    }

    if (frame->rsv & (RSV1|RSV2|RSV3)) {
      con->fail(CLOSE_EXTENSION);
      WS_WERR(1, "Unexpected rsv bits.\n");
      return 0;
    }

    if (opcode & 0x8) {
      // Control frames (RFC 6455 section 5.5): at most 125 bytes of
      // payload, never fragmented, and only the known opcodes.
      if (sizeof(frame->data) > 125) {
        WS_WERR(1, "Received too big control frame. closing connection.\n");
        con->fail();
        return 0;
      }
      if (!frame->fin) {
        WS_WERR(1, "Received fragmented control frame. closing connection.\n");
        con->fail();
        return 0;
      }
      if (opcode > FRAME_PONG) {
        WS_WERR(1, "Received unknown control frame. closing connection.\n");
        con->fail();
        return 0;
      }
    }
    // Opcodes 0x3-0x7 are reserved for future data frames.
    if (opcode >= 0x3 && opcode <= 0x7) {
      WS_WERR(1, "Received reserved non control opcode frame.\n");
      con->fail();
      return 0;
    }

    return frame;
  }
}
|
|