diff options
author | Badlop <badlop@process-one.net> | 2013-03-19 13:29:15 +0100 |
---|---|---|
committer | Badlop <badlop@process-one.net> | 2013-03-19 13:30:17 +0100 |
commit | f92a94a7378856c2ff3217152ee62a9d035aadc6 (patch) | |
tree | 119ef937e6ebc97790d655145bf29fa394bf1177 /src/pgsql/pgsql_proto.erl | |
parent | Copied MySQL erlang library from ejabberd-modules SVN (diff) |
Copied PostgreSQL erlang library from ejabberd-modules SVN
Diffstat (limited to 'src/pgsql/pgsql_proto.erl')
-rw-r--r-- | src/pgsql/pgsql_proto.erl | 650 |
1 files changed, 650 insertions, 0 deletions
diff --git a/src/pgsql/pgsql_proto.erl b/src/pgsql/pgsql_proto.erl new file mode 100644 index 000000000..fe49a4846 --- /dev/null +++ b/src/pgsql/pgsql_proto.erl @@ -0,0 +1,650 @@ +%%% File : pgsql_proto.erl +%%% Author : Christian Sunesson <chrisu@kth.se> +%%% Description : PostgreSQL protocol driver +%%% Created : 9 May 2005 + +%%% This is the protocol handling part of the PostgreSQL driver, it turns packages into +%%% erlang term messages and back. + +-module(pgsql_proto). + +-behaviour(gen_server). + +%% TODO: +%% When factorizing make clear distinction between message and packet. +%% Packet == binary on-wire representation +%% Message = parsed Packet as erlang terms. + +%%% Version 3.0 of the protocol. +%%% Supported in postgres from version 7.4 +-define(PROTOCOL_MAJOR, 3). +-define(PROTOCOL_MINOR, 0). + +%%% PostgreSQL protocol message codes +-define(PG_BACKEND_KEY_DATA, $K). +-define(PG_PARAMETER_STATUS, $S). +-define(PG_ERROR_MESSAGE, $E). +-define(PG_NOTICE_RESPONSE, $N). +-define(PG_EMPTY_RESPONSE, $I). +-define(PG_ROW_DESCRIPTION, $T). +-define(PG_DATA_ROW, $D). +-define(PG_READY_FOR_QUERY, $Z). +-define(PG_AUTHENTICATE, $R). +-define(PG_BIND, $B). +-define(PG_PARSE, $P). +-define(PG_COMMAND_COMPLETE, $C). +-define(PG_PARSE_COMPLETE, $1). +-define(PG_BIND_COMPLETE, $2). +-define(PG_CLOSE_COMPLETE, $3). +-define(PG_PORTAL_SUSPENDED, $s). +-define(PG_NO_DATA, $n). + +-export([start/1, start_link/1]). + +%% gen_server callbacks +-export([init/1, + handle_call/3, + handle_cast/2, + code_change/3, + handle_info/2, + terminate/2]). + +%% For protocol unwrapping, pgsql_tcp for example. +-export([decode_packet/3]). +-export([encode_message/2]). +-export([encode/2]). + +-import(pgsql_util, [option/3]). +-import(pgsql_util, [socket/1]). +-import(pgsql_util, [send/2, send_int/2, send_msg/3]). +-import(pgsql_util, [recv_msg/2, recv_msg/1, recv_byte/2, recv_byte/1]). +-import(pgsql_util, [string/1, make_pair/2, split_pair/2]). +-import(pgsql_util, [count_string/1, to_string/2]). +-import(pgsql_util, [coldescs/3, datacoldescs/3]). +-import(pgsql_util, [to_integer/1, to_atom/1]). + +-record(state, {options, driver, params, socket, oidmap, as_binary}). + +start(Options) -> + gen_server:start(?MODULE, [self(), Options], []). + +start_link(Options) -> + gen_server:start_link(?MODULE, [self(), Options], []). + +init([DriverPid, Options]) -> + %%io:format("Init~n", []), + %% Default values: We connect to localhost on the standard TCP/IP + %% port. + Host = option(Options, host, "localhost"), + Port = option(Options, port, 5432), + AsBinary = option(Options, as_binary, false), + + case socket({tcp, Host, Port}) of + {ok, Sock} -> + connect(#state{options = Options, + driver = DriverPid, + as_binary = AsBinary, + socket = Sock}); + Error -> + Reason = {init, Error}, + {stop, Reason} + end. + +connect(StateData) -> + %%io:format("Connect~n", []), + %% Connection settings for database-login. + %% TODO: Check if the default values are relevant: + UserName = option(StateData#state.options, user, "cos"), + DatabaseName = option(StateData#state.options, database, "template1"), + + %% Make protocol startup packet. + Version = <<?PROTOCOL_MAJOR:16/integer, ?PROTOCOL_MINOR:16/integer>>, + User = make_pair(user, UserName), + Database = make_pair(database, DatabaseName), + StartupPacket = <<Version/binary, + User/binary, + Database/binary, + 0>>, + + %% Backend will continue with authentication after the startup packet + PacketSize = 4 + size(StartupPacket), + Sock = StateData#state.socket, + ok = gen_tcp:send(Sock, <<PacketSize:32/integer, StartupPacket/binary>>), + authenticate(StateData). + + +authenticate(StateData) -> + %% Await authentication request from backend. + Sock = StateData#state.socket, + AsBin = StateData#state.as_binary, + {ok, Code, Packet} = recv_msg(Sock, 5000), + {ok, Value} = decode_packet(Code, Packet, AsBin), + case Value of + %% Error response + {error_message, Message} -> + {stop, {authentication, Message}}; + {authenticate, {AuthMethod, Salt}} -> + case AuthMethod of + 0 -> % Auth ok + setup(StateData, []); + 1 -> % Kerberos 4 + {stop, {nyi, auth_kerberos4}}; + 2 -> % Kerberos 5 + {stop, {nyi, auth_kerberos5}}; + 3 -> % Plaintext password + Password = option(StateData#state.options, password, ""), + EncodedPass = encode_message(pass_plain, Password), + ok = send(Sock, EncodedPass), + authenticate(StateData); + 4 -> % Hashed password + {stop, {nyi, auth_crypt}}; + 5 -> % MD5 password + Password = option(StateData#state.options, password, ""), + User = option(StateData#state.options, user, ""), + EncodedPass = encode_message(pass_md5, + {User, Password, Salt}), + ok = send(Sock, EncodedPass), + authenticate(StateData); + _ -> + {stop, {authentication, {unknown, AuthMethod}}} + end; + %% Unknown message received + Any -> + {stop, {protocol_error, Any}} + end. + +setup(StateData, Params) -> + %% Receive startup messages until ReadyForQuery + Sock = StateData#state.socket, + AsBin = StateData#state.as_binary, + {ok, Code, Package} = recv_msg(Sock, 5000), + {ok, Pair} = decode_packet(Code, Package, AsBin), + case Pair of + %% BackendKeyData, necessary for issuing cancel requests + {backend_key_data, {Pid, Secret}} -> + Params1 = [{secret, {Pid, Secret}} | Params], + setup(StateData, Params1); + %% ParameterStatus, a key-value pair. + {parameter_status, {Key, Value}} -> + Params1 = [{{parameter, Key}, Value} | Params], + setup(StateData, Params1); + %% Error message, with a sequence of <<Code:8/integer, String, 0>> + %% of error descriptions. Code==0 terminates the Reason. + {error_message, Message} -> + gen_tcp:close(Sock), + {stop, {error_response, Message}}; + %% Notice Response, with a sequence of <<Code:8/integer, String,0>> + %% identified fields. Code==0 terminates the Notice. + {notice_response, Notice} -> + deliver(StateData, {pgsql_notice, Notice}), + setup(StateData, Params); + %% Ready for Query, backend is ready for a new query cycle + {ready_for_query, _Status} -> + connected(StateData#state{params = Params}, Sock); + Any -> + {stop, {unknown_setup, Any}} + end. + +%% Connected state. Can now start to push messages +%% between frontend and backend. But first some setup. +connected(StateData, Sock) -> + %% Protocol unwrapping process. Factored out to make future + %% SSL and unix domain support easier. Store process under + %% 'socket' in the process dictionary. + AsBin = StateData#state.as_binary, + {ok, Unwrapper} = pgsql_tcp:start_link(Sock, self(), AsBin), + ok = gen_tcp:controlling_process(Sock, Unwrapper), + + %% Lookup oid to type names and store them in a dictionary under + %% 'oidmap' in the process dictionary. + Packet = encode_message(squery, "SELECT oid, typname FROM pg_type"), + ok = send(Sock, Packet), + {ok, [{_, _ColDesc, Rows}]} = process_squery([], AsBin), + Rows1 = lists:map(fun ([CodeS, NameS]) -> + Code = to_integer(CodeS), + Name = to_atom(NameS), + {Code, Name} + end, + Rows), + OidMap = dict:from_list(Rows1), + + {ok, StateData#state{oidmap = OidMap}}. + + +handle_call(terminate, _From, State) -> + Sock = State#state.socket, + Packet = encode_message(terminate, []), + ok = send(Sock, Packet), + gen_tcp:close(Sock), + Reply = ok, + {stop, normal, Reply, State}; + +%% Simple query +handle_call({squery, Query}, _From, State) -> + Sock = State#state.socket, + AsBin = State#state.as_binary, + Packet = encode_message(squery, Query), + ok = send(Sock, Packet), + {ok, Result} = process_squery([], AsBin), + case lists:keymember(error, 1, Result) of + true -> + RBPacket = encode_message(squery, "ROLLBACK"), + ok = send(Sock, RBPacket), + {ok, _RBResult} = process_squery([], AsBin); + _ -> + ok + end, + Reply = {ok, Result}, + {reply, Reply, State}; + +%% Extended query +%% simplistic version using the unnammed prepared statement and portal. +handle_call({equery, {Query, Params}}, _From, State) -> + Sock = State#state.socket, + ParseP = encode_message(parse, {"", Query, []}), + BindP = encode_message(bind, {"", "", Params, [binary]}), + DescribeP = encode_message(describe, {portal, ""}), + ExecuteP = encode_message(execute, {"", 0}), + SyncP = encode_message(sync, []), + ok = send(Sock, [ParseP, BindP, DescribeP, ExecuteP, SyncP]), + + {ok, Command, Desc, Status, Logs} = process_equery(State, []), + + OidMap = State#state.oidmap, + NameTypes = lists:map(fun({Name, _Format, _ColNo, Oid, _, _, _}) -> + {Name, dict:fetch(Oid, OidMap)} + end, + Desc), + Reply = {ok, Command, Status, NameTypes, Logs}, + {reply, Reply, State}; + +%% Prepare a statement, so it can be used for queries later on. +handle_call({prepare, {Name, Query}}, _From, State) -> + Sock = State#state.socket, + send_message(Sock, parse, {Name, Query, []}), + send_message(Sock, describe, {prepared_statement, Name}), + send_message(Sock, sync, []), + {ok, State, ParamDesc, ResultDesc} = process_prepare({[], []}), + OidMap = State#state.oidmap, + ParamTypes = + lists:map(fun (Oid) -> dict:fetch(Oid, OidMap) end, ParamDesc), + ResultNameTypes = lists:map(fun ({ColName, _Format, _ColNo, Oid, _, _, _}) -> + {ColName, dict:fetch(Oid, OidMap)} + end, + ResultDesc), + Reply = {ok, State, ParamTypes, ResultNameTypes}, + {reply, Reply, State}; + +%% Close a prepared statement. +handle_call({unprepare, Name}, _From, State) -> + Sock = State#state.socket, + send_message(Sock, close, {prepared_statement, Name}), + send_message(Sock, sync, []), + {ok, _Status} = process_unprepare(), + Reply = ok, + {reply, Reply, State}; + +%% Execute a prepared statement +handle_call({execute, {Name, Params}}, _From, State) -> + Sock = State#state.socket, + %%io:format("execute: ~p ~p ~n", [Name, Params]), + begin % Issue first requests for the prepared statement. + BindP = encode_message(bind, {"", Name, Params, [binary]}), + DescribeP = encode_message(describe, {portal, ""}), + ExecuteP = encode_message(execute, {"", 0}), + FlushP = encode_message(flush, []), + ok = send(Sock, [BindP, DescribeP, ExecuteP, FlushP]) + end, + receive + {pgsql, {bind_complete, _}} -> % Bind reply first. + %% Collect response to describe message, + %% which gives a hint of the rest of the messages. + {ok, Command, Result} = process_execute(State, Sock), + + begin % Close portal and end extended query. + CloseP = encode_message(close, {portal, ""}), + SyncP = encode_message(sync, []), + ok = send(Sock, [CloseP, SyncP]) + end, + receive + %% Collect response to close message. + {pgsql, {close_complete, _}} -> + receive + %% Collect response to sync message. + {pgsql, {ready_for_query, _Status}} -> + %%io:format("execute: ~p ~p ~p~n", + %% [Status, Command, Result]), + Reply = {ok, {Command, Result}}, + {reply, Reply, State}; + {pgsql, Unknown} -> + {stop, Unknown, {error, Unknown}, State} + end; + {pgsql, Unknown} -> + {stop, Unknown, {error, Unknown}, State} + end; + {pgsql, Unknown} -> + {stop, Unknown, {error, Unknown}, State} + end; + +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. + + +handle_cast(_Msg, State) -> + {noreply, State}. + + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + + +%% Socket closed or socket error messages. +handle_info({socket, _Sock, Condition}, State) -> + {stop, {socket, Condition}, State}; +handle_info(_Info, State) -> + {noreply, State}. + + +terminate(_Reason, _State) -> + ok. + + +deliver(State, Message) -> + DriverPid = State#state.driver, + DriverPid ! Message. + +%% In the process_squery state we collect responses until the backend is +%% done processing. +process_squery(Log, AsBin) -> + receive + {pgsql, {row_description, Cols}} -> + {ok, Command, Rows} = process_squery_cols([], AsBin), + process_squery([{Command, Cols, Rows}|Log], AsBin); + {pgsql, {command_complete, Command}} -> + process_squery([Command|Log], AsBin); + {pgsql, {ready_for_query, _Status}} -> + {ok, lists:reverse(Log)}; + {pgsql, {error_message, Error}} -> + process_squery([{error, Error}|Log], AsBin); + {pgsql, _Any} -> + process_squery(Log, AsBin) + end. +process_squery_cols(Log, AsBin) -> + receive + {pgsql, {data_row, Row}} -> + process_squery_cols( + [lists:map( + fun(null) -> + null; + (R) when AsBin == true -> + R; + (R) -> + binary_to_list(R) + end, Row) | Log], AsBin); + {pgsql, {command_complete, Command}} -> + {ok, Command, lists:reverse(Log)} + end. + +process_equery(State, Log) -> + receive + %% Consume parse and bind complete messages when waiting for the first + %% first row_description message. What happens if the equery doesnt + %% return a result set? + {pgsql, {parse_complete, _}} -> + process_equery(State, Log); + {pgsql, {bind_complete, _}} -> + process_equery(State, Log); + {pgsql, {row_description, Descs}} -> + OidMap = State#state.oidmap, + {ok, Descs1} = pgsql_util:decode_descs(OidMap, Descs), + process_equery_datarow(Descs1, Log, {undefined, Descs, undefined}); + {pgsql, Any} -> + process_equery(State, [Any|Log]) + end. + +process_equery_datarow(Types, Log, Info={Command, Desc, Status}) -> + receive + %% + {pgsql, {command_complete, Command1}} -> + process_equery_datarow(Types, Log, {Command1, Desc, Status}); + {pgsql, {ready_for_query, Status1}} -> + {ok, Command, Desc, Status1, lists:reverse(Log)}; + {pgsql, {data_row, Row}} -> + {ok, DecodedRow} = pgsql_util:decode_row(Types, Row), + process_equery_datarow(Types, [DecodedRow|Log], Info); + {pgsql, Any} -> + process_equery_datarow(Types, [Any|Log], Info) + end. + +process_prepare(Info={ParamDesc, ResultDesc}) -> + receive + {pgsql, {no_data, _}} -> + process_prepare({ParamDesc, []}); + {pgsql, {parse_complete, _}} -> + process_prepare(Info); + {pgsql, {parameter_description, Oids}} -> + process_prepare({Oids, ResultDesc}); + {pgsql, {row_description, Desc}} -> + process_prepare({ParamDesc, Desc}); + {pgsql, {ready_for_query, Status}} -> + {ok, Status, ParamDesc, ResultDesc}; + {pgsql, Any} -> + io:format("process_prepare: ~p~n", [Any]), + process_prepare(Info) + end. + +process_unprepare() -> + receive + {pgsql, {ready_for_query, Status}} -> + {ok, Status}; + {pgsql, {close_complate, []}} -> + process_unprepare(); + {pgsql, Any} -> + io:format("process_unprepare: ~p~n", [Any]), + process_unprepare() + end. + +process_execute(State, Sock) -> + %% Either the response begins with a no_data or a row_description + %% Needs to return {ok, Status, Result} + %% where Result = {Command, ...} + receive + {pgsql, {no_data, _}} -> + {ok, _Command, _Result} = process_execute_nodata(); + {pgsql, {row_description, Descs}} -> + OidMap = State#state.oidmap, + {ok, Types} = pgsql_util:decode_descs(OidMap, Descs), + {ok, _Command, _Result} = + process_execute_resultset(Sock, Types, []); + {pgsql, Unknown} -> + exit(Unknown) + end. + +process_execute_nodata() -> + receive + {pgsql, {command_complete, Cmd}} -> + Command = if is_binary(Cmd) -> + binary_to_list(Cmd); + true -> + Cmd + end, + case Command of + "INSERT "++Rest -> + {ok, [{integer, _, _Table}, + {integer, _, NRows}], _} = erl_scan:string(Rest), + {ok, 'INSERT', NRows}; + "SELECT" -> + {ok, 'SELECT', should_not_happen}; + "DELETE "++Rest -> + {ok, [{integer, _, NRows}], _} = + erl_scan:string(Rest), + {ok, 'DELETE', NRows}; + Any -> + {ok, nyi, Any} + end; + + {pgsql, Unknown} -> + exit(Unknown) + end. +process_execute_resultset(Sock, Types, Log) -> + receive + {pgsql, {command_complete, Command}} -> + {ok, to_atom(Command), lists:reverse(Log)}; + {pgsql, {data_row, Row}} -> + {ok, DecodedRow} = pgsql_util:decode_row(Types, Row), + process_execute_resultset(Sock, Types, [DecodedRow|Log]); + {pgsql, {portal_suspended, _}} -> + throw(portal_suspended); + {pgsql, Any} -> + %%process_execute_resultset(Types, [Any|Log]) + exit(Any) + end. + +%% With a message type Code and the payload Packet apropriate +%% decoding procedure can proceed. +decode_packet(Code, Packet, AsBin) -> + Ret = fun(CodeName, Values) -> {ok, {CodeName, Values}} end, + case Code of + ?PG_ERROR_MESSAGE -> + Message = pgsql_util:errordesc(Packet, AsBin), + Ret(error_message, Message); + ?PG_EMPTY_RESPONSE -> + Ret(empty_response, []); + ?PG_ROW_DESCRIPTION -> + <<_Columns:16/integer, ColDescs/binary>> = Packet, + Descs = coldescs(ColDescs, [], AsBin), + Ret(row_description, Descs); + ?PG_READY_FOR_QUERY -> + <<State:8/integer>> = Packet, + case State of + $I -> + Ret(ready_for_query, idle); + $T -> + Ret(ready_for_query, transaction); + $E -> + Ret(ready_for_query, failed_transaction) + end; + ?PG_COMMAND_COMPLETE -> + {Task, _} = to_string(Packet, AsBin), + Ret(command_complete, Task); + ?PG_DATA_ROW -> + <<NumberCol:16/integer, RowData/binary>> = Packet, + ColData = datacoldescs(NumberCol, RowData, []), + Ret(data_row, ColData); + ?PG_BACKEND_KEY_DATA -> + <<Pid:32/integer, Secret:32/integer>> = Packet, + Ret(backend_key_data, {Pid, Secret}); + ?PG_PARAMETER_STATUS -> + {Key, Value} = split_pair(Packet, AsBin), + Ret(parameter_status, {Key, Value}); + ?PG_NOTICE_RESPONSE -> + Ret(notice_response, []); + ?PG_AUTHENTICATE -> + <<AuthMethod:32/integer, Salt/binary>> = Packet, + Ret(authenticate, {AuthMethod, Salt}); + ?PG_PARSE_COMPLETE -> + Ret(parse_complete, []); + ?PG_BIND_COMPLETE -> + Ret(bind_complete, []); + ?PG_PORTAL_SUSPENDED -> + Ret(portal_suspended, []); + ?PG_CLOSE_COMPLETE -> + Ret(close_complete, []); + $t -> + <<_NParams:16/integer, OidsP/binary>> = Packet, + Oids = pgsql_util:oids(OidsP, []), + Ret(parameter_description, Oids); + ?PG_NO_DATA -> + Ret(no_data, []); + _Any -> + Ret(unknown, [Code]) + end. + +send_message(Sock, Type, Values) -> + %%io:format("send_message:~p~n", [{Type, Values}]), + Packet = encode_message(Type, Values), + ok = send(Sock, Packet). + +%% Add header to a message. +encode(Code, Packet) -> + Len = size(Packet) + 4, + <<Code:8/integer, Len:4/integer-unit:8, Packet/binary>>. + +%% Encode a message of a given type. +encode_message(pass_plain, Password) -> + Pass = pgsql_util:pass_plain(Password), + encode($p, Pass); +encode_message(pass_md5, {User, Password, Salt}) -> + Pass = pgsql_util:pass_md5(User, Password, Salt), + encode($p, Pass); +encode_message(terminate, _) -> + encode($X, <<>>); +encode_message(squery, Query) -> % squery as in simple query. + encode($Q, string(Query)); +encode_message(close, {Object, Name}) -> + Type = case Object of prepared_statement -> $S; portal -> $P end, + String = string(Name), + encode($C, <<Type/integer, String/binary>>); +encode_message(describe, {Object, Name}) -> + ObjectP = case Object of prepared_statement -> $S; portal -> $P end, + NameP = string(Name), + encode($D, <<ObjectP:8/integer, NameP/binary>>); +encode_message(flush, _) -> + encode($H, <<>>); +encode_message(parse, {Name, Query, _Oids}) -> + StringName = string(Name), + StringQuery = string(Query), + encode($P, <<StringName/binary, StringQuery/binary, 0:16/integer>>); +encode_message(bind, {NamePortal, NamePrepared, + Parameters, ResultFormats}) -> + PortalP = string(NamePortal), + PreparedP = string(NamePrepared), + + ParamFormatsList = lists:map( + fun (Bin) when is_binary(Bin) -> <<1:16/integer>>; + (_Text) -> <<0:16/integer>> end, + Parameters), + ParamFormatsP = erlang:list_to_binary(ParamFormatsList), + + NParameters = length(Parameters), + ParametersList = lists:map( + fun (null) -> + Minus = -1, + <<Minus:32/integer>>; + (Bin) when is_binary(Bin) -> + Size = size(Bin), + <<Size:32/integer, Bin/binary>>; + (Integer) when is_integer(Integer) -> + List = integer_to_list(Integer), + Bin = list_to_binary(List), + Size = size(Bin), + <<Size:32/integer, Bin/binary>>; + (Text) -> + Bin = list_to_binary(Text), + Size = size(Bin), + <<Size:32/integer, Bin/binary>> + end, + Parameters), + ParametersP = erlang:list_to_binary(ParametersList), + + NResultFormats = length(ResultFormats), + ResultFormatsList = lists:map( + fun (binary) -> <<1:16/integer>>; + (text) -> <<0:16/integer>> end, + ResultFormats), + ResultFormatsP = erlang:list_to_binary(ResultFormatsList), + + %%io:format("encode bind: ~p~n", [{PortalP, PreparedP, + %% NParameters, ParamFormatsP, + %% NParameters, ParametersP, + %% NResultFormats, ResultFormatsP}]), + encode($B, <<PortalP/binary, PreparedP/binary, + NParameters:16/integer, ParamFormatsP/binary, + NParameters:16/integer, ParametersP/binary, + NResultFormats:16/integer, ResultFormatsP/binary>>); +encode_message(execute, {Portal, Limit}) -> + String = string(Portal), + encode($E, <<String/binary, Limit:32/integer>>); +encode_message(sync, _) -> + encode($S, <<>>). |