conn.go 45 KB


  1. package gokb
  2. import (
  3. "bufio"
  4. "context"
  5. "crypto/md5"
  6. "crypto/sha256"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "net"
  14. "os"
  15. "os/user"
  16. "path"
  17. "path/filepath"
  18. "strconv"
  19. "strings"
  20. "time"
  21. "unicode"
  22. "kingbase.com/gokb/oid"
  23. "kingbase.com/gokb/scram"
  24. )
  // GOKBVersion_V008R003C002B1009 is an exported marker constant whose name
  // appears to encode the driver build (V008R003C002B1009) — TODO confirm.
  // Being the only constant in its own declaration, iota gives it the value 0.
  const GOKBVersion_V008R003C002B1009 = iota
  // Common error types
  var (
  	// ErrNotSupported reports an unsupported command. Its use sites are not
  	// visible in this part of the file.
  	ErrNotSupported = errors.New("kb: Unsupported command")
  	// ErrInFailedTransaction is returned by Commit when the transaction had
  	// already failed; Commit rolls the transaction back before returning it.
  	ErrInFailedTransaction = errors.New("kb: Could not complete operation in a failed transaction")
  	// ErrSSLNotSupported is returned by (*conn).ssl when the server answers
  	// the SSL request with anything other than 'S'.
  	ErrSSLNotSupported = errors.New("kb: SSL is not enabled on the server")
  	// ErrSSLKeyHasWorldPermissions presumably rejects an over-permissive
  	// client private key file; the check lives outside this view.
  	ErrSSLKeyHasWorldPermissions = errors.New("kb: Private key file has group or world access. Permissions should be u=rw (0600) or less")
  	// ErrCouldNotDetectUsername presumably reports that no user name was
  	// configured and none could be determined (use site not shown here).
  	ErrCouldNotDetectUsername = errors.New("kb: Could not detect default username. Please provide one explicitly")
  	// errUnexpectedReady is returned by simpleExec when ReadyForQuery arrives
  	// with neither a result nor an error.
  	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  	// errNoRowsAffected and errNoLastInsertID are returned by the noRows
  	// result used for empty statements.
  	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
  )
  37. // Driver is the Kingbase database driver.
  38. type Driver struct{}
  39. // Open opens a new connection to the database. name is a connection string.
  40. // Most users should only use it through database/sql package from the standard
  41. // library.
  42. func (d *Driver) Open(name string) (driver.Conn, error) {
  43. return Open(name)
  44. }
  45. func init() {
  46. sql.Register("kingbase", &Driver{})
  47. }
  // parameterStatus holds session information the driver tracks from the
  // server. NOTE(review): the code that fills it in (presumably
  // processParameterStatus) is not visible in this part of the file.
  type parameterStatus struct {
  	// server version in the same format as server_version_num, or 0 if
  	// unavailable
  	serverVersion int
  	// the current location based on the TimeZone value of the session, if
  	// available
  	currentLocation *time.Location
  }
  56. type transactionStatus byte
  57. const (
  58. txnStatusIdle transactionStatus = 'I'
  59. txnStatusIdleInTransaction transactionStatus = 'T'
  60. txnStatusInFailedTransaction transactionStatus = 'E'
  61. )
  62. func (s transactionStatus) String() string {
  63. switch s {
  64. case txnStatusIdle:
  65. return "idle"
  66. case txnStatusIdleInTransaction:
  67. return "idle in transaction"
  68. case txnStatusInFailedTransaction:
  69. return "in a failed transaction"
  70. default:
  71. errorf("unknown transactionStatus %d", s)
  72. }
  73. panic("not reached")
  74. }
  // Dialer is the dialer interface. It can be used to obtain more control over
  // how kb creates network connections.
  //
  // dial calls DialTimeout when connect_timeout is set and Dial otherwise,
  // unless the value also implements DialerContext.
  type Dialer interface {
  	Dial(network, address string) (net.Conn, error)
  	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  }

  // DialerContext is the context-aware dialer interface.
  //
  // When a Dialer also implements DialerContext, dial uses DialContext instead
  // of Dial/DialTimeout, so the caller's context (plus any connect_timeout)
  // governs the dial.
  type DialerContext interface {
  	DialContext(ctx context.Context, network, address string) (net.Conn, error)
  }
  // defaultDialer is the Dialer that Open uses; it delegates to a zero-value
  // net.Dialer.
  type defaultDialer struct {
  	d net.Dialer
  }

  // Dial connects to address on the named network without a timeout.
  func (dd defaultDialer) Dial(network, address string) (net.Conn, error) {
  	return dd.d.Dial(network, address)
  }

  // DialTimeout is like Dial but gives up once timeout has elapsed.
  func (dd defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
  	deadlineCtx, release := context.WithTimeout(context.Background(), timeout)
  	defer release()
  	return dd.DialContext(deadlineCtx, network, address)
  }

  // DialContext connects to address on the named network, honouring ctx.
  func (dd defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  	return dd.d.DialContext(ctx, network, address)
  }
  // conn is a single connection to a Kingbase server. It is the driver.Conn
  // produced by (*Connector).open, and begin also returns it as the driver.Tx.
  type conn struct {
  	// c is the network connection; ssl may replace it with an upgraded one.
  	c net.Conn
  	// buf is a buffered reader over c; recvMessage reads from it.
  	buf *bufio.Reader
  	// namei is the counter gname uses to generate statement names.
  	namei int
  	// scratch backs outgoing message headers (writeBuf) and small incoming
  	// message bodies (recvMessage); data sliced from it is overwritten by the
  	// next such use.
  	scratch [512]byte
  	// txnStatus is the last transaction status reported by the server
  	// (presumably set by processReadyForQuery, not shown here).
  	txnStatus transactionStatus
  	// txnFinish, if non-nil, is called by closeTxn when Commit or Rollback
  	// returns.
  	txnFinish func()
  	// Save connection arguments to use during CancelRequest.
  	dialer Dialer
  	opts values
  	// Cancellation key data for use with CancelRequest messages.
  	processID int
  	secretKey int
  	// parameterStatus holds server-reported session parameters.
  	parameterStatus parameterStatus
  	// saveMessageType/saveMessageBuffer hold at most one pushed-back message
  	// (see saveMessage); a zero type means nothing is saved.
  	saveMessageType byte
  	saveMessageBuffer []byte
  	// If true, this connection is bad and all public-facing functions should
  	// return ErrBadConn.
  	bad bool
  	// If set, this connection should never use the binary format when
  	// receiving query results from prepared statements. Only provided for
  	// debugging.
  	disablePreparedBinaryResult bool
  	// Whether to always send []byte parameters over as binary. Enables single
  	// round-trip mode for non-prepared Query calls.
  	binaryParameters bool
  	// If true this connection is in the middle of a COPY
  	inCopy bool
  	// If not nil, notices will be synchronously sent here
  	noticeHandler func(*Error)
  }
  130. // Handle driver-side settings in parsed connection string.
  131. func (cn *conn) handleDriverSettings(o values) (err error) {
  132. boolSetting := func(key string, val *bool) error {
  133. if value, ok := o[key]; ok {
  134. if value == "yes" {
  135. *val = true
  136. } else if value == "no" {
  137. *val = false
  138. } else {
  139. return fmt.Errorf("unrecognized value %q for %s", value, key)
  140. }
  141. }
  142. return nil
  143. }
  144. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  145. if err != nil {
  146. return err
  147. }
  148. return boolSetting("binary_parameters", &cn.binaryParameters)
  149. }
  150. func (cn *conn) handleKbpass(o values) {
  151. // if a password was supplied, do not process .kbpass
  152. if _, ok := o["password"]; ok {
  153. return
  154. }
  155. filename := os.Getenv("KBPASSFILE")
  156. if filename == "" {
  157. // XXX this code doesn't work on Windows where the default filename is
  158. // XXX %APPDATA%\kingbase\kbpass.conf
  159. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  160. userHome := os.Getenv("HOME")
  161. if userHome == "" {
  162. user, err := user.Current()
  163. if err != nil {
  164. return
  165. }
  166. userHome = user.HomeDir
  167. }
  168. filename = filepath.Join(userHome, ".kbpass")
  169. }
  170. fileinfo, err := os.Stat(filename)
  171. if err != nil {
  172. return
  173. }
  174. mode := fileinfo.Mode()
  175. if mode&(0x77) != 0 {
  176. // XXX should warn about incorrect .kbpass permissions as psql does
  177. return
  178. }
  179. file, err := os.Open(filename)
  180. if err != nil {
  181. return
  182. }
  183. defer file.Close()
  184. scanner := bufio.NewScanner(io.Reader(file))
  185. hostname := o["host"]
  186. ntw, _ := network(o)
  187. port := o["port"]
  188. db := o["dbname"]
  189. username := o["user"]
  190. // From: https://github.com/tg/kbpass/blob/master/reader.go
  191. getFields := func(s string) []string {
  192. fs := make([]string, 0, 5)
  193. f := make([]rune, 0, len(s))
  194. var esc bool
  195. for _, c := range s {
  196. switch {
  197. case esc:
  198. f = append(f, c)
  199. esc = false
  200. case c == '\\':
  201. esc = true
  202. case c == ':':
  203. fs = append(fs, string(f))
  204. f = f[:0]
  205. default:
  206. f = append(f, c)
  207. }
  208. }
  209. return append(fs, string(f))
  210. }
  211. for scanner.Scan() {
  212. line := scanner.Text()
  213. if len(line) == 0 || line[0] == '#' {
  214. continue
  215. }
  216. split := getFields(line)
  217. if len(split) != 5 {
  218. continue
  219. }
  220. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  221. o["password"] = split[4]
  222. return
  223. }
  224. }
  225. }
  226. func (cn *conn) writeBuf(b byte) *writeBuf {
  227. cn.scratch[0] = b
  228. return &writeBuf{
  229. buf: cn.scratch[:5],
  230. pos: 1,
  231. }
  232. }
  233. // Open opens a new connection to the database. dsn is a connection string.
  234. // Most users should only use it through database/sql package from the standard
  235. // library.
  236. func Open(dsn string) (_ driver.Conn, err error) {
  237. return DialOpen(defaultDialer{}, dsn)
  238. }
  239. // DialOpen opens a new connection to the database using a dialer.
  240. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
  241. c, err := NewConnector(dsn)
  242. if err != nil {
  243. return nil, err
  244. }
  245. c.dialer = d
  246. return c.open(context.Background())
  247. }
  248. func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
  249. // Handle any panics during connection initialization. Note that we
  250. // specifically do *not* want to use errRecover(), as that would turn any
  251. // connection errors into ErrBadConns, hiding the real error message from
  252. // the user.
  253. defer errRecoverNoErrBadConn(&err)
  254. o := c.opts
  255. cn = &conn{
  256. opts: o,
  257. dialer: c.dialer,
  258. }
  259. err = cn.handleDriverSettings(o)
  260. if err != nil {
  261. return nil, err
  262. }
  263. cn.handleKbpass(o)
  264. cn.c, err = dial(ctx, c.dialer, o)
  265. if err != nil {
  266. return nil, err
  267. }
  268. err = cn.ssl(o)
  269. if err != nil {
  270. if cn.c != nil {
  271. cn.c.Close()
  272. }
  273. return nil, err
  274. }
  275. // cn.startup panics on error. Make sure we don't leak cn.c.
  276. panicking := true
  277. defer func() {
  278. if panicking {
  279. cn.c.Close()
  280. }
  281. }()
  282. cn.buf = bufio.NewReader(cn.c)
  283. cn.startup(o)
  284. // reset the deadline, in case one was set (see dial)
  285. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  286. err = cn.c.SetDeadline(time.Time{})
  287. }
  288. panicking = false
  289. return cn, err
  290. }
  291. func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
  292. network, address := network(o)
  293. // SSL is not necessary or supported over UNIX domain sockets
  294. if network == "unix" {
  295. o["sslmode"] = "disable"
  296. }
  297. // Zero or not specified means wait indefinitely.
  298. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  299. seconds, err := strconv.ParseInt(timeout, 10, 0)
  300. if err != nil {
  301. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  302. }
  303. duration := time.Duration(seconds) * time.Second
  304. // connect_timeout should apply to the entire connection establishment
  305. // procedure, so we both use a timeout for the TCP connection
  306. // establishment and set a deadline for doing the initial handshake.
  307. // The deadline is then reset after startup() is done.
  308. deadline := time.Now().Add(duration)
  309. var conn net.Conn
  310. if dctx, ok := d.(DialerContext); ok {
  311. ctx, cancel := context.WithTimeout(ctx, duration)
  312. defer cancel()
  313. conn, err = dctx.DialContext(ctx, network, address)
  314. } else {
  315. conn, err = d.DialTimeout(network, address, duration)
  316. }
  317. if err != nil {
  318. return nil, err
  319. }
  320. err = conn.SetDeadline(deadline)
  321. return conn, err
  322. }
  323. if dctx, ok := d.(DialerContext); ok {
  324. return dctx.DialContext(ctx, network, address)
  325. }
  326. return d.Dial(network, address)
  327. }
  328. func network(o values) (string, string) {
  329. host := o["host"]
  330. if strings.HasPrefix(host, "/") {
  331. sockPath := path.Join(host, ".s.KINGBASE."+o["port"])
  332. return "unix", sockPath
  333. }
  334. return "tcp", net.JoinHostPort(host, o["port"])
  335. }
  // values holds connection options keyed by option name (for example "host",
  // "port", "dbname"), as filled in by parseOpts.
  type values map[string]string
  // scanner implements a tokenizer for libkci-style option strings. It walks
  // the string rune by rune.
  type scanner struct {
  	s []rune
  	i int
  }

  // newScanner returns a new scanner initialized with the option string s.
  func newScanner(s string) *scanner {
  	return &scanner{s: []rune(s)}
  }

  // Next returns the next rune and advances past it.
  // It returns 0, false if the end of the text has been reached.
  func (s *scanner) Next() (rune, bool) {
  	if s.i < len(s.s) {
  		r := s.s[s.i]
  		s.i++
  		return r, true
  	}
  	return 0, false
  }

  // SkipSpaces consumes whitespace and returns the first non-whitespace rune.
  // It returns 0, false if the end of the text has been reached.
  func (s *scanner) SkipSpaces() (rune, bool) {
  	for {
  		r, ok := s.Next()
  		if !ok || !unicode.IsSpace(r) {
  			return r, ok
  		}
  	}
  }
  365. // parseOpts parses the options from name and adds them to the values.
  366. //
  367. // The parsing code is based on conninfo_parse from libkci's fe-connect.c
  368. func parseOpts(name string, o values) error {
  369. s := newScanner(name)
  370. for {
  371. var (
  372. keyRunes, valRunes []rune
  373. r rune
  374. ok bool
  375. )
  376. if r, ok = s.SkipSpaces(); !ok {
  377. break
  378. }
  379. // Scan the key
  380. for !unicode.IsSpace(r) && r != '=' {
  381. keyRunes = append(keyRunes, r)
  382. if r, ok = s.Next(); !ok {
  383. break
  384. }
  385. }
  386. // Skip any whitespace if we're not at the = yet
  387. if r != '=' {
  388. r, ok = s.SkipSpaces()
  389. }
  390. // The current character should be =
  391. if r != '=' || !ok {
  392. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  393. }
  394. // Skip any whitespace after the =
  395. if r, ok = s.SkipSpaces(); !ok {
  396. // If we reach the end here, the last value is just an empty string as per libkci.
  397. o[string(keyRunes)] = ""
  398. break
  399. }
  400. if r != '\'' {
  401. for !unicode.IsSpace(r) {
  402. if r == '\\' {
  403. if r, ok = s.Next(); !ok {
  404. return fmt.Errorf(`missing character after backslash`)
  405. }
  406. }
  407. valRunes = append(valRunes, r)
  408. if r, ok = s.Next(); !ok {
  409. break
  410. }
  411. }
  412. } else {
  413. quote:
  414. for {
  415. if r, ok = s.Next(); !ok {
  416. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  417. }
  418. switch r {
  419. case '\'':
  420. break quote
  421. case '\\':
  422. r, _ = s.Next()
  423. fallthrough
  424. default:
  425. valRunes = append(valRunes, r)
  426. }
  427. }
  428. }
  429. o[string(keyRunes)] = string(valRunes)
  430. }
  431. return nil
  432. }
  433. func (cn *conn) isInTransaction() bool {
  434. return cn.txnStatus == txnStatusIdleInTransaction ||
  435. cn.txnStatus == txnStatusInFailedTransaction
  436. }
  437. func (cn *conn) checkIsInTransaction(intxn bool) {
  438. if cn.isInTransaction() != intxn {
  439. cn.bad = true
  440. errorf("unexpected transaction status %v", cn.txnStatus)
  441. }
  442. }
  443. func (cn *conn) Begin() (_ driver.Tx, err error) {
  444. return cn.begin("")
  445. }
  446. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  447. if cn.bad {
  448. return nil, driver.ErrBadConn
  449. }
  450. defer cn.errRecover(&err)
  451. cn.checkIsInTransaction(false)
  452. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  453. if err != nil {
  454. return nil, err
  455. }
  456. if commandTag != "BEGIN" {
  457. cn.bad = true
  458. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  459. }
  460. if cn.txnStatus != txnStatusIdleInTransaction {
  461. cn.bad = true
  462. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  463. }
  464. return cn, nil
  465. }
  466. func (cn *conn) closeTxn() {
  467. if finish := cn.txnFinish; finish != nil {
  468. finish()
  469. }
  470. }
  // Commit implements driver.Tx.
  //
  // The txnFinish callback always runs on return (deferred closeTxn), even
  // when the connection is already bad. A failed transaction is rolled back
  // instead of committed, and ErrInFailedTransaction is returned. A COMMIT
  // error while the session still reports an open transaction marks the
  // connection bad, as does an unexpected command tag.
  func (cn *conn) Commit() (err error) {
  	defer cn.closeTxn()
  	if cn.bad {
  		return driver.ErrBadConn
  	}
  	defer cn.errRecover(&err)
  	cn.checkIsInTransaction(true)
  	// We don't want the client to think that everything is okay if it tries
  	// to commit a failed transaction. However, no matter what we return,
  	// database/sql will release this connection back into the free connection
  	// pool so we have to abort the current transaction here. Note that you
  	// would get the same behaviour if you issued a COMMIT in a failed
  	// transaction, so it's also the least surprising thing to do here.
  	if cn.txnStatus == txnStatusInFailedTransaction {
  		if err := cn.rollback(); err != nil {
  			return err
  		}
  		return ErrInFailedTransaction
  	}
  	_, commandTag, err := cn.simpleExec("COMMIT")
  	if err != nil {
  		// Still in a transaction after a failed COMMIT means the session
  		// state is unknown; do not reuse the connection.
  		if cn.isInTransaction() {
  			cn.bad = true
  		}
  		return err
  	}
  	if commandTag != "COMMIT" {
  		cn.bad = true
  		return fmt.Errorf("unexpected command tag %s", commandTag)
  	}
  	cn.checkIsInTransaction(false)
  	return nil
  }
  504. func (cn *conn) Rollback() (err error) {
  505. defer cn.closeTxn()
  506. if cn.bad {
  507. return driver.ErrBadConn
  508. }
  509. defer cn.errRecover(&err)
  510. return cn.rollback()
  511. }
  512. func (cn *conn) rollback() (err error) {
  513. cn.checkIsInTransaction(true)
  514. _, commandTag, err := cn.simpleExec("ROLLBACK")
  515. if err != nil {
  516. if cn.isInTransaction() {
  517. cn.bad = true
  518. }
  519. return err
  520. }
  521. if commandTag != "ROLLBACK" {
  522. return fmt.Errorf("unexpected command tag %s", commandTag)
  523. }
  524. cn.checkIsInTransaction(false)
  525. return nil
  526. }
  527. func (cn *conn) gname() string {
  528. cn.namei++
  529. return strconv.FormatInt(int64(cn.namei), 10)
  530. }
  // simpleExec sends q as a single 'Q' (simple query) message and consumes
  // the responses until ReadyForQuery ('Z').
  //
  // It keeps the result and command tag of the last 'C' (CommandComplete,
  // parsed by parseComplete) and uses emptyRows for 'I' (empty statement).
  // An 'E' is parsed into err; reading does not stop there, but continues
  // until 'Z'. Row descriptions and data rows ('T', 'D') are discarded. If
  // 'Z' arrives with neither a result nor an error, err is
  // errUnexpectedReady. Any other message type marks the connection bad and
  // is reported via errorf, which is expected to panic; callers such as
  // begin recover with cn.errRecover.
  func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  	b := cn.writeBuf('Q')
  	b.string(q)
  	cn.send(b)
  	for {
  		t, r := cn.recv1()
  		switch t {
  		case 'C':
  			res, commandTag = cn.parseComplete(r.string())
  		case 'Z':
  			cn.processReadyForQuery(r)
  			if res == nil && err == nil {
  				err = errUnexpectedReady
  			}
  			// done
  			return
  		case 'E':
  			err = parseError(r)
  		case 'I':
  			res = emptyRows
  		case 'T', 'D':
  			// ignore any results
  		default:
  			cn.bad = true
  			errorf("unknown response for simple query: %q", t)
  		}
  	}
  }
  // simpleQuery sends q as a 'Q' (simple query) message and returns a *rows
  // ready to be read.
  //
  // It returns as soon as the first DataRow ('D') arrives. That message is
  // pushed back with saveMessage so the rows reader can consume it, and the
  // remaining messages are left unread on the connection. Otherwise it reads
  // until ReadyForQuery ('Z'). Statements that produce no rows still yield a
  // *rows with done=true. An ErrorResponse ('E') discards any partial result
  // and is returned as err once 'Z' arrives. Protocol violations mark the
  // connection bad and go through errorf. The resulting panic is recovered
  // by the deferred cn.errRecover, presumably into an error return.
  func (cn *conn) simpleQuery(q string) (res *rows, err error) {
  	defer cn.errRecover(&err)
  	b := cn.writeBuf('Q')
  	b.string(q)
  	cn.send(b)
  	for {
  		t, r := cn.recv1()
  		switch t {
  		case 'C', 'I':
  			// We allow queries which don't return any results through Query as
  			// well as Exec. We still have to give database/sql a rows object
  			// the user can close, though, to avoid connections from being
  			// leaked. A "rows" with done=true works fine for that purpose.
  			if err != nil {
  				cn.bad = true
  				errorf("unexpected message %q in simple query execution", t)
  			}
  			if res == nil {
  				res = &rows{
  					cn: cn,
  				}
  			}
  			// Set the result and tag to the last command complete if there wasn't a
  			// query already run. Although queries usually return from here and cede
  			// control to Next, a query with zero results does not.
  			if t == 'C' && res.colNames == nil {
  				res.result, res.tag = cn.parseComplete(r.string())
  			}
  			res.done = true
  		case 'Z':
  			cn.processReadyForQuery(r)
  			// done
  			return
  		case 'E':
  			res = nil
  			err = parseError(r)
  		case 'D':
  			if res == nil {
  				cn.bad = true
  				errorf("unexpected DataRow in simple query execution")
  			}
  			// the query didn't fail; kick off to Next
  			cn.saveMessage(t, r)
  			return
  		case 'T':
  			// res might be non-nil here if we received a previous
  			// CommandComplete, but that's fine; just overwrite it
  			res = &rows{cn: cn}
  			res.rowsHeader = parsePortalRowDescribe(r)
  			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
  			// until the first DataRow has been received.
  		default:
  			cn.bad = true
  			errorf("unknown response for simple query: %q", t)
  		}
  	}
  }
  616. type noRows struct{}
  617. var emptyRows noRows
  618. var _ driver.Result = noRows{}
  619. func (noRows) LastInsertId() (int64, error) {
  620. return 0, errNoLastInsertID
  621. }
  622. func (noRows) RowsAffected() (int64, error) {
  623. return 0, errNoRowsAffected
  624. }
  625. // Decides which column formats to use for a prepared statement. The input is
  626. // an array of type oids, one element per result column.
  627. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  628. if len(colTyps) == 0 {
  629. return nil, colFmtDataAllText
  630. }
  631. colFmts = make([]format, len(colTyps))
  632. if forceText {
  633. return colFmts, colFmtDataAllText
  634. }
  635. allBinary := true
  636. allText := true
  637. for i, t := range colTyps {
  638. switch t.OID {
  639. // This is the list of types to use binary mode for when receiving them
  640. // through a prepared statement. If a type appears in this list, it
  641. // must also be implemented in binaryDecode in encode.go.
  642. case oid.T_bytea:
  643. fallthrough
  644. case oid.T_int8:
  645. fallthrough
  646. case oid.T_int4:
  647. fallthrough
  648. case oid.T_int2:
  649. fallthrough
  650. case oid.T_uuid:
  651. colFmts[i] = formatBinary
  652. allText = false
  653. default:
  654. allBinary = false
  655. }
  656. }
  657. if allBinary {
  658. return colFmts, colFmtDataAllBinary
  659. } else if allText {
  660. return colFmts, colFmtDataAllText
  661. } else {
  662. colFmtData = make([]byte, 2+len(colFmts)*2)
  663. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  664. for i, v := range colFmts {
  665. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  666. }
  667. return colFmts, colFmtData
  668. }
  669. }
  // prepareTo prepares q under the statement name stmtName ("" is used by
  // query and Exec for the unnamed statement) and returns the described
  // statement.
  //
  // Three messages are built in one buffer and sent with a single write: 'P'
  // (name, query text, then an int16 0 — presumably zero pre-specified
  // parameter types), 'D' with 'S' and the name (presumably Describe
  // Statement), and 'S' (presumably Sync). The matching replies are then
  // read in order. Result column formats come from decideColumnFormats,
  // forced to text when disablePreparedBinaryResult is set. Errors surface as
  // panics from the read helpers (not shown here).
  func (cn *conn) prepareTo(q, stmtName string) *stmt {
  	st := &stmt{cn: cn, name: stmtName}
  	b := cn.writeBuf('P')
  	b.string(st.name)
  	b.string(q)
  	b.int16(0)
  	b.next('D')
  	b.byte('S')
  	b.string(st.name)
  	b.next('S')
  	cn.send(b)
  	cn.readParseResponse()
  	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  	cn.readReadyForQuery()
  	return st
  }
  687. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  688. if cn.bad {
  689. return nil, driver.ErrBadConn
  690. }
  691. defer cn.errRecover(&err)
  692. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  693. s, err := cn.prepareCopyIn(q)
  694. if err == nil {
  695. cn.inCopy = true
  696. }
  697. return s, err
  698. }
  699. return cn.prepareTo(q, cn.gname()), nil
  700. }
  // Close implements driver.Conn. It sends an 'X' message (presumably
  // Terminate) and always closes the network connection, even if cn is
  // marked bad.
  //
  // Deferred calls run in reverse order, so cn.c.Close runs before
  // errRecover. The error from sending 'X' takes precedence; the error from
  // cn.c.Close is returned only if sending succeeded.
  func (cn *conn) Close() (err error) {
  	// Skip cn.bad return here because we always want to close a connection.
  	defer cn.errRecover(&err)
  	// Ensure that cn.c.Close is always run. Since error handling is done with
  	// panics and cn.errRecover, the Close must be in a defer.
  	defer func() {
  		cerr := cn.c.Close()
  		if err == nil {
  			err = cerr
  		}
  	}()
  	// Don't go through send(); ListenerConn relies on us not scribbling on the
  	// scratch buffer of this connection.
  	return cn.sendSimpleMessage('X')
  }
  716. // Implement the "Queryer" interface
  717. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  718. return cn.query(query, args)
  719. }
  720. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  721. if cn.bad {
  722. return nil, driver.ErrBadConn
  723. }
  724. if cn.inCopy {
  725. return nil, errCopyInProgress
  726. }
  727. defer cn.errRecover(&err)
  728. // Check to see if we can use the "simpleQuery" interface, which is
  729. // *much* faster than going through prepare/exec
  730. if len(args) == 0 {
  731. return cn.simpleQuery(query)
  732. }
  733. if cn.binaryParameters {
  734. cn.sendBinaryModeQuery(query, args)
  735. cn.readParseResponse()
  736. cn.readBindResponse()
  737. rows := &rows{cn: cn}
  738. rows.rowsHeader = cn.readPortalDescribeResponse()
  739. cn.postExecuteWorkaround()
  740. return rows, nil
  741. }
  742. st := cn.prepareTo(query, "")
  743. st.exec(args)
  744. return &rows{
  745. cn: cn,
  746. rowsHeader: st.rowsHeader,
  747. }, nil
  748. }
  // Implement the optional "Execer" interface for one-shot queries
  //
  // Without arguments Exec uses simpleExec. With binary_parameters enabled,
  // parameters go out in a single round trip via sendBinaryModeQuery.
  // Otherwise the query is prepared as the unnamed statement and executed.
  // Protocol errors are raised as panics and handled by the deferred
  // cn.errRecover.
  func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  	if cn.bad {
  		return nil, driver.ErrBadConn
  	}
  	defer cn.errRecover(&err)
  	// Check to see if we can use the "simpleExec" interface, which is
  	// *much* faster than going through prepare/exec
  	if len(args) == 0 {
  		// ignore commandTag, our caller doesn't care
  		r, _, err := cn.simpleExec(query)
  		return r, err
  	}
  	if cn.binaryParameters {
  		cn.sendBinaryModeQuery(query, args)
  		cn.readParseResponse()
  		cn.readBindResponse()
  		cn.readPortalDescribeResponse()
  		cn.postExecuteWorkaround()
  		res, _, err = cn.readExecuteResponse("Execute")
  		return res, err
  	}
  	// Use the unnamed statement to defer planning until bind
  	// time, or else value-based selectivity estimates cannot be
  	// used.
  	st := cn.prepareTo(query, "")
  	r, err := st.Exec(args)
  	if err != nil {
  		// Re-raised on purpose so the deferred errRecover classifies the
  		// error (NOTE(review): presumably so it can mark the connection bad
  		// or map it to ErrBadConn; errRecover is not shown here).
  		panic(err)
  	}
  	return r, err
  }
  781. func (cn *conn) send(m *writeBuf) {
  782. _, err := cn.c.Write(m.wrap())
  783. if err != nil {
  784. panic(err)
  785. }
  786. }
  787. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  788. _, err := cn.c.Write((m.wrap())[1:])
  789. return err
  790. }
  791. // Send a message of type typ to the server on the other end of cn. The
  792. // message should have no payload. This method does not use the scratch
  793. // buffer.
  794. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  795. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  796. return err
  797. }
  798. // saveMessage memorizes a message and its buffer in the conn struct.
  799. // recvMessage will then return these values on the next call to it. This
  800. // method is useful in cases where you have to see what the next message is
  801. // going to be (e.g. to see whether it's an error or not) but you can't handle
  802. // the message yourself.
  803. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  804. if cn.saveMessageType != 0 {
  805. cn.bad = true
  806. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  807. }
  808. cn.saveMessageType = typ
  809. cn.saveMessageBuffer = *buf
  810. }
  811. // recvMessage receives any message from the backend, or returns an error if
  812. // a problem occurred while reading the message.
  813. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  814. // workaround for a QueryRow bug, see exec
  815. if cn.saveMessageType != 0 {
  816. t := cn.saveMessageType
  817. *r = cn.saveMessageBuffer
  818. cn.saveMessageType = 0
  819. cn.saveMessageBuffer = nil
  820. return t, nil
  821. }
  822. x := cn.scratch[:5]
  823. _, err := io.ReadFull(cn.buf, x)
  824. if err != nil {
  825. return 0, err
  826. }
  827. // read the type and length of the message that follows
  828. t := x[0]
  829. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  830. var y []byte
  831. if n <= len(cn.scratch) {
  832. y = cn.scratch[:n]
  833. } else {
  834. y = make([]byte, n)
  835. }
  836. _, err = io.ReadFull(cn.buf, y)
  837. if err != nil {
  838. return 0, err
  839. }
  840. *r = y
  841. return t, nil
  842. }
  843. // recv receives a message from the backend, but if an error happened while
  844. // reading the message or the received message was an ErrorResponse, it panics.
  845. // NoticeResponses are ignored. This function should generally be used only
  846. // during the startup sequence.
  847. func (cn *conn) recv() (t byte, r *readBuf) {
  848. for {
  849. var err error
  850. r = &readBuf{}
  851. t, err = cn.recvMessage(r)
  852. if err != nil {
  853. panic(err)
  854. }
  855. switch t {
  856. case 'E':
  857. panic(parseError(r))
  858. case 'N':
  859. if n := cn.noticeHandler; n != nil {
  860. n(parseError(r))
  861. }
  862. default:
  863. return
  864. }
  865. }
  866. }
  867. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  868. // the caller to avoid an allocation.
  869. func (cn *conn) recv1Buf(r *readBuf) byte {
  870. for {
  871. t, err := cn.recvMessage(r)
  872. if err != nil {
  873. panic(err)
  874. }
  875. switch t {
  876. case 'A':
  877. // ignore
  878. case 'N':
  879. if n := cn.noticeHandler; n != nil {
  880. n(parseError(r))
  881. }
  882. case 'S':
  883. cn.processParameterStatus(r)
  884. default:
  885. return t
  886. }
  887. }
  888. }
  889. // recv1 receives a message from the backend, panicking if an error occurs
  890. // while attempting to read it. All asynchronous messages are ignored, with
  891. // the exception of ErrorResponse.
  892. func (cn *conn) recv1() (t byte, r *readBuf) {
  893. r = &readBuf{}
  894. t = cn.recv1Buf(r)
  895. return t, r
  896. }
  897. func (cn *conn) ssl(o values) error {
  898. upgrade, err := ssl(o)
  899. if err != nil {
  900. return err
  901. }
  902. if upgrade == nil {
  903. // Nothing to do
  904. return nil
  905. }
  906. w := cn.writeBuf(0)
  907. w.int32(80877103)
  908. if err = cn.sendStartupPacket(w); err != nil {
  909. return err
  910. }
  911. b := cn.scratch[:1]
  912. _, err = io.ReadFull(cn.c, b)
  913. if err != nil {
  914. return err
  915. }
  916. if b[0] != 'S' {
  917. return ErrSSLNotSupported
  918. }
  919. cn.c, err = upgrade(cn.c)
  920. return err
  921. }
  // isDriverSetting reports whether key configures only the driver itself.
  // Such keys are not sent to the server as run-time parameters in the
  // startup packet.
  func isDriverSetting(key string) bool {
  	switch key {
  	case "host", "port", "password",
  		"sslmode", "sslcert", "sslkey", "sslrootcert",
  		"fallback_application_name", "connect_timeout",
  		"disable_prepared_binary_result", "binary_parameters":
  		return true
  	}
  	return false
  }
  // startup sends the startup packet built from o and then processes the
  // server's replies until ReadyForQuery ('Z').
  //
  // Every option for which isDriverSetting is false is sent as a run-time
  // parameter, with "dbname" renamed to "database". Because o is a map, the
  // order of parameters in the packet varies between calls. In the reply
  // loop, backend key data ('K') and parameter status ('S') are recorded, and
  // authentication requests ('R') are handled by auth. Failures panic (from
  // recv, errorf or the explicit panic below); open guards this with
  // errRecoverNoErrBadConn.
  func (cn *conn) startup(o values) {
  	w := cn.writeBuf(0)
  	// 196608 == 3<<16: protocol version 3.0 (major<<16 | minor) in the
  	// PostgreSQL-style startup packet.
  	w.int32(196608)
  	// Send the backend the name of the database we want to connect to, and the
  	// user we want to connect as. Additionally, we send over any run-time
  	// parameters potentially included in the connection string. If the server
  	// doesn't recognize any of them, it will reply with an error.
  	for k, v := range o {
  		if isDriverSetting(k) {
  			// skip options which can't be run-time parameters
  			continue
  		}
  		// The protocol requires us to supply the database name as "database"
  		// instead of "dbname".
  		if k == "dbname" {
  			k = "database"
  		}
  		w.string(k)
  		w.string(v)
  	}
  	// An empty name terminates the parameter list.
  	w.string("")
  	if err := cn.sendStartupPacket(w); err != nil {
  		panic(err)
  	}
  	for {
  		t, r := cn.recv()
  		switch t {
  		case 'K':
  			cn.processBackendKeyData(r)
  		case 'S':
  			cn.processParameterStatus(r)
  		case 'R':
  			cn.auth(r, o)
  		case 'Z':
  			cn.processReadyForQuery(r)
  			return
  		default:
  			errorf("unknown response for startup: %q", t)
  		}
  	}
  }
  986. func (cn *conn) auth(r *readBuf, o values) {
  987. switch code := r.int32(); code {
  988. case 0:
  989. // OK
  990. case 3:
  991. w := cn.writeBuf('p')
  992. w.string(o["password"])
  993. cn.send(w)
  994. t, r := cn.recv()
  995. if t != 'R' {
  996. errorf("unexpected password response: %q", t)
  997. }
  998. if r.int32() != 0 {
  999. errorf("unexpected authentication response: %q", t)
  1000. }
  1001. case 5:
  1002. s := string(r.next(4))
  1003. w := cn.writeBuf('p')
  1004. w.string("md5" + md5s(md5s(o["password"] + strings.ToUpper(o["user"]))+s))
  1005. cn.send(w)
  1006. t, r := cn.recv()
  1007. if t != 'R' {
  1008. errorf("unexpected password response: %q", t)
  1009. }
  1010. if r.int32() != 0 {
  1011. errorf("unexpected authentication response: %q", t)
  1012. }
  1013. case 10:
  1014. sc := scram.NewClient(sha256.New, o["user"], o["password"])
  1015. sc.Step(nil)
  1016. if sc.Err() != nil {
  1017. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1018. }
  1019. scOut := sc.Out()
  1020. w := cn.writeBuf('p')
  1021. w.string("SCRAM-SHA-256")
  1022. w.int32(len(scOut))
  1023. w.bytes(scOut)
  1024. cn.send(w)
  1025. t, r := cn.recv()
  1026. if t != 'R' {
  1027. errorf("unexpected password response: %q", t)
  1028. }
  1029. if r.int32() != 11 {
  1030. errorf("unexpected authentication response: %q", t)
  1031. }
  1032. nextStep := r.next(len(*r))
  1033. sc.Step(nextStep)
  1034. if sc.Err() != nil {
  1035. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1036. }
  1037. scOut = sc.Out()
  1038. w = cn.writeBuf('p')
  1039. w.bytes(scOut)
  1040. cn.send(w)
  1041. t, r = cn.recv()
  1042. if t != 'R' {
  1043. errorf("unexpected password response: %q", t)
  1044. }
  1045. if r.int32() != 12 {
  1046. errorf("unexpected authentication response: %q", t)
  1047. }
  1048. nextStep = r.next(len(*r))
  1049. sc.Step(nextStep)
  1050. if sc.Err() != nil {
  1051. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1052. }
  1053. default:
  1054. errorf("unknown authentication response: %d", code)
  1055. }
  1056. }
  1057. type format int
  1058. const formatText format = 0
  1059. const formatBinary format = 1
  1060. // One result-column format code with the value 1 (i.e. all binary).
  1061. var colFmtDataAllBinary = []byte{0, 1, 0, 1}
  1062. // No result-column format codes (i.e. all text).
  1063. var colFmtDataAllText = []byte{0, 0}
  1064. type stmt struct {
  1065. cn *conn
  1066. name string
  1067. rowsHeader
  1068. colFmtData []byte
  1069. paramTyps []oid.Oid
  1070. closed bool
  1071. }
  1072. func (st *stmt) Close() (err error) {
  1073. if st.closed {
  1074. return nil
  1075. }
  1076. if st.cn.bad {
  1077. return driver.ErrBadConn
  1078. }
  1079. defer st.cn.errRecover(&err)
  1080. w := st.cn.writeBuf('C')
  1081. w.byte('S')
  1082. w.string(st.name)
  1083. st.cn.send(w)
  1084. st.cn.send(st.cn.writeBuf('S'))
  1085. t, _ := st.cn.recv1()
  1086. if t != '3' {
  1087. st.cn.bad = true
  1088. errorf("unexpected close response: %q", t)
  1089. }
  1090. st.closed = true
  1091. t, r := st.cn.recv1()
  1092. if t != 'Z' {
  1093. st.cn.bad = true
  1094. errorf("expected ready for query, but got: %q", t)
  1095. }
  1096. st.cn.processReadyForQuery(r)
  1097. return nil
  1098. }
  1099. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1100. if st.cn.bad {
  1101. return nil, driver.ErrBadConn
  1102. }
  1103. defer st.cn.errRecover(&err)
  1104. st.exec(v)
  1105. return &rows{
  1106. cn: st.cn,
  1107. rowsHeader: st.rowsHeader,
  1108. }, nil
  1109. }
  1110. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1111. if st.cn.bad {
  1112. return nil, driver.ErrBadConn
  1113. }
  1114. defer st.cn.errRecover(&err)
  1115. st.exec(v)
  1116. res, _, err = st.cn.readExecuteResponse("simple query")
  1117. return res, err
  1118. }
  1119. func (st *stmt) exec(v []driver.Value) {
  1120. if len(v) >= 65536 {
  1121. errorf("got %d parameters but Kingbase only supports 65535 parameters", len(v))
  1122. }
  1123. if len(v) != len(st.paramTyps) {
  1124. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1125. }
  1126. cn := st.cn
  1127. w := cn.writeBuf('B')
  1128. w.byte(0) // unnamed portal
  1129. w.string(st.name)
  1130. if cn.binaryParameters {
  1131. cn.sendBinaryParameters(w, v)
  1132. } else {
  1133. w.int16(0)
  1134. w.int16(len(v))
  1135. for i, x := range v {
  1136. if x == nil {
  1137. w.int32(-1)
  1138. } else {
  1139. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1140. w.int32(len(b))
  1141. w.bytes(b)
  1142. }
  1143. }
  1144. }
  1145. w.bytes(st.colFmtData)
  1146. w.next('E')
  1147. w.byte(0)
  1148. w.int32(0)
  1149. w.next('S')
  1150. cn.send(w)
  1151. cn.readBindResponse()
  1152. cn.postExecuteWorkaround()
  1153. }
  1154. func (st *stmt) NumInput() int {
  1155. return len(st.paramTyps)
  1156. }
  1157. // parseComplete parses the "command tag" from a CommandComplete message, and
  1158. // returns the number of rows affected (if applicable) and a string
  1159. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1160. // command tag could not be parsed, parseComplete panics.
  1161. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1162. commandsWithAffectedRows := []string{
  1163. "SELECT ",
  1164. // INSERT is handled below
  1165. "UPDATE ",
  1166. "DELETE ",
  1167. "FETCH ",
  1168. "MOVE ",
  1169. "COPY ",
  1170. }
  1171. var affectedRows *string
  1172. for _, tag := range commandsWithAffectedRows {
  1173. if strings.HasPrefix(commandTag, tag) {
  1174. t := commandTag[len(tag):]
  1175. affectedRows = &t
  1176. commandTag = tag[:len(tag)-1]
  1177. break
  1178. }
  1179. }
  1180. // INSERT also includes the oid of the inserted row in its command tag.
  1181. // Oids in user tables are deprecated, and the oid is only returned when
  1182. // exactly one row is inserted, so it's unlikely to be of value to any
  1183. // real-world application and we can ignore it.
  1184. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1185. parts := strings.Split(commandTag, " ")
  1186. if len(parts) != 3 {
  1187. cn.bad = true
  1188. errorf("unexpected INSERT command tag %s", commandTag)
  1189. }
  1190. affectedRows = &parts[len(parts)-1]
  1191. commandTag = "INSERT"
  1192. }
  1193. // There should be no affected rows attached to the tag, just return it
  1194. if affectedRows == nil {
  1195. return driver.RowsAffected(0), commandTag
  1196. }
  1197. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1198. if err != nil {
  1199. cn.bad = true
  1200. errorf("could not parse commandTag: %s", err)
  1201. }
  1202. return driver.RowsAffected(n), commandTag
  1203. }
  1204. type rowsHeader struct {
  1205. colNames []string
  1206. colTyps []fieldDesc
  1207. colFmts []format
  1208. }
  1209. type rows struct {
  1210. cn *conn
  1211. finish func()
  1212. rowsHeader
  1213. done bool
  1214. rb readBuf
  1215. result driver.Result
  1216. tag string
  1217. next *rowsHeader
  1218. }
  1219. func (rs *rows) Close() error {
  1220. if finish := rs.finish; finish != nil {
  1221. defer finish()
  1222. }
  1223. // no need to look at cn.bad as Next() will
  1224. for {
  1225. err := rs.Next(nil)
  1226. switch err {
  1227. case nil:
  1228. case io.EOF:
  1229. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1230. // description, used with HasNextResultSet). We need to fetch messages until
  1231. // we hit a 'Z', which is done by waiting for done to be set.
  1232. if rs.done {
  1233. return nil
  1234. }
  1235. default:
  1236. return err
  1237. }
  1238. }
  1239. }
  1240. func (rs *rows) Columns() []string {
  1241. return rs.colNames
  1242. }
  1243. func (rs *rows) Result() driver.Result {
  1244. if rs.result == nil {
  1245. return emptyRows
  1246. }
  1247. return rs.result
  1248. }
  1249. func (rs *rows) Tag() string {
  1250. return rs.tag
  1251. }
  1252. func (rs *rows) Next(dest []driver.Value) (err error) {
  1253. if rs.done {
  1254. return io.EOF
  1255. }
  1256. conn := rs.cn
  1257. if conn.bad {
  1258. return driver.ErrBadConn
  1259. }
  1260. defer conn.errRecover(&err)
  1261. for {
  1262. t := conn.recv1Buf(&rs.rb)
  1263. switch t {
  1264. case 'E':
  1265. err = parseError(&rs.rb)
  1266. case 'C', 'I':
  1267. if t == 'C' {
  1268. rs.result, rs.tag = conn.parseComplete(rs.rb.string())
  1269. }
  1270. continue
  1271. case 'Z':
  1272. conn.processReadyForQuery(&rs.rb)
  1273. rs.done = true
  1274. if err != nil {
  1275. return err
  1276. }
  1277. return io.EOF
  1278. case 'D':
  1279. n := rs.rb.int16()
  1280. if err != nil {
  1281. conn.bad = true
  1282. errorf("unexpected DataRow after error %s", err)
  1283. }
  1284. if n < len(dest) {
  1285. dest = dest[:n]
  1286. }
  1287. for i := range dest {
  1288. l := rs.rb.int32()
  1289. if l == -1 {
  1290. dest[i] = nil
  1291. continue
  1292. }
  1293. dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
  1294. }
  1295. return
  1296. case 'T':
  1297. next := parsePortalRowDescribe(&rs.rb)
  1298. rs.next = &next
  1299. return io.EOF
  1300. default:
  1301. errorf("unexpected message after execute: %q", t)
  1302. }
  1303. }
  1304. }
  1305. func (rs *rows) HasNextResultSet() bool {
  1306. hasNext := rs.next != nil && !rs.done
  1307. return hasNext
  1308. }
  1309. func (rs *rows) NextResultSet() error {
  1310. if rs.next == nil {
  1311. return io.EOF
  1312. }
  1313. rs.rowsHeader = *rs.next
  1314. rs.next = nil
  1315. return nil
  1316. }
  1317. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1318. // used as part of an SQL statement. For example:
  1319. //
  1320. // tblname := "my_table"
  1321. // data := "my_data"
  1322. // quoted := kb.QuoteIdentifier(tblname)
  1323. // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  1324. //
  1325. // Any double quotes in name will be escaped. The quoted identifier will be
  1326. // case sensitive when used in a query. If the input string contains a zero
  1327. // byte, the result will be truncated immediately before it.
  1328. func QuoteIdentifier(name string) string {
  1329. end := strings.IndexRune(name, 0)
  1330. if end > -1 {
  1331. name = name[:end]
  1332. }
  1333. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1334. }
  1335. // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
  1336. // to DDL and other statements that do not accept parameters) to be used as part
  1337. // of an SQL statement. For example:
  1338. //
  1339. // exp_date := kb.QuoteLiteral("2023-01-05 15:00:00Z")
  1340. // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
  1341. //
  1342. // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
  1343. // replaced by two backslashes (i.e. "\\") and the C-style escape identifier
  1344. // that Kingbase provides ('E') will be prepended to the string.
  1345. func QuoteLiteral(literal string) string {
  1346. // This follows the Kingbase internal algorithm for handling quoted literals
  1347. // from libkci, which can be found in the "PQEscapeStringInternal" function,
  1348. // which is found in the libkci/fe-exec.c source file:
  1349. // https://git.kingbase.org/gitweb/?p=kingbase.git;a=blob;f=src/interfaces/libkci/fe-exec.c
  1350. //
  1351. // substitute any single-quotes (') with two single-quotes ('')
  1352. literal = strings.Replace(literal, `'`, `''`, -1)
  1353. // determine if the string has any backslashes (\) in it.
  1354. // if it does, replace any backslashes (\) with two backslashes (\\)
  1355. // then, we need to wrap the entire string with a Kingbase
  1356. // C-style escape. Per how "PQEscapeStringInternal" handles this case, we
  1357. // also add a space before the "E"
  1358. if strings.Contains(literal, `\`) {
  1359. literal = strings.Replace(literal, `\`, `\\`, -1)
  1360. literal = ` E'` + literal + `'`
  1361. } else {
  1362. // otherwise, we can just wrap the literal with a pair of single quotes
  1363. literal = `'` + literal + `'`
  1364. }
  1365. return literal
  1366. }
  1367. func md5s(s string) string {
  1368. h := md5.New()
  1369. h.Write([]byte(s))
  1370. return fmt.Sprintf("%x", h.Sum(nil))
  1371. }
  1372. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1373. // Do one pass over the parameters to see if we're going to send any of
  1374. // them over in binary. If we are, create a paramFormats array at the
  1375. // same time.
  1376. var paramFormats []int
  1377. for i, x := range args {
  1378. _, ok := x.([]byte)
  1379. if ok {
  1380. if paramFormats == nil {
  1381. paramFormats = make([]int, len(args))
  1382. }
  1383. paramFormats[i] = 1
  1384. }
  1385. }
  1386. if paramFormats == nil {
  1387. b.int16(0)
  1388. } else {
  1389. b.int16(len(paramFormats))
  1390. for _, x := range paramFormats {
  1391. b.int16(x)
  1392. }
  1393. }
  1394. b.int16(len(args))
  1395. for _, x := range args {
  1396. if x == nil {
  1397. b.int32(-1)
  1398. } else {
  1399. datum := binaryEncode(&cn.parameterStatus, x)
  1400. b.int32(len(datum))
  1401. b.bytes(datum)
  1402. }
  1403. }
  1404. }
  1405. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1406. if len(args) >= 65536 {
  1407. errorf("got %d parameters but Kingbase only supports 65535 parameters", len(args))
  1408. }
  1409. b := cn.writeBuf('P')
  1410. b.byte(0) // unnamed statement
  1411. b.string(query)
  1412. b.int16(0)
  1413. b.next('B')
  1414. b.int16(0) // unnamed portal and statement
  1415. cn.sendBinaryParameters(b, args)
  1416. b.bytes(colFmtDataAllText)
  1417. b.next('D')
  1418. b.byte('P')
  1419. b.byte(0) // unnamed portal
  1420. b.next('E')
  1421. b.byte(0)
  1422. b.int32(0)
  1423. b.next('S')
  1424. cn.send(b)
  1425. }
  1426. func (cn *conn) processParameterStatus(r *readBuf) {
  1427. var err error
  1428. param := r.string()
  1429. switch param {
  1430. case "server_version":
  1431. var major1 int
  1432. var major2 int
  1433. _, err = fmt.Sscanf(r.string(), "V%dR%d", &major1, &major2)
  1434. if err == nil {
  1435. cn.parameterStatus.serverVersion = major1*10000 + major2*100
  1436. }
  1437. case "TimeZone":
  1438. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1439. if err != nil {
  1440. cn.parameterStatus.currentLocation = nil
  1441. }
  1442. default:
  1443. // ignore
  1444. }
  1445. }
  1446. func (cn *conn) processReadyForQuery(r *readBuf) {
  1447. cn.txnStatus = transactionStatus(r.byte())
  1448. }
  1449. func (cn *conn) readReadyForQuery() {
  1450. t, r := cn.recv1()
  1451. switch t {
  1452. case 'Z':
  1453. cn.processReadyForQuery(r)
  1454. return
  1455. default:
  1456. cn.bad = true
  1457. errorf("unexpected message %q; expected ReadyForQuery", t)
  1458. }
  1459. }
  1460. func (cn *conn) processBackendKeyData(r *readBuf) {
  1461. cn.processID = r.int32()
  1462. cn.secretKey = r.int32()
  1463. }
  1464. func (cn *conn) readParseResponse() {
  1465. t, r := cn.recv1()
  1466. switch t {
  1467. case '1':
  1468. return
  1469. case 'E':
  1470. err := parseError(r)
  1471. cn.readReadyForQuery()
  1472. panic(err)
  1473. default:
  1474. cn.bad = true
  1475. errorf("unexpected Parse response %q", t)
  1476. }
  1477. }
  1478. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1479. for {
  1480. t, r := cn.recv1()
  1481. switch t {
  1482. case 't':
  1483. nparams := r.int16()
  1484. paramTyps = make([]oid.Oid, nparams)
  1485. for i := range paramTyps {
  1486. paramTyps[i] = r.oid()
  1487. }
  1488. case 'n':
  1489. return paramTyps, nil, nil
  1490. case 'T':
  1491. colNames, colTyps = parseStatementRowDescribe(r)
  1492. return paramTyps, colNames, colTyps
  1493. case 'E':
  1494. err := parseError(r)
  1495. cn.readReadyForQuery()
  1496. panic(err)
  1497. default:
  1498. cn.bad = true
  1499. errorf("unexpected Describe statement response %q", t)
  1500. }
  1501. }
  1502. }
  1503. func (cn *conn) readPortalDescribeResponse() rowsHeader {
  1504. t, r := cn.recv1()
  1505. switch t {
  1506. case 'T':
  1507. return parsePortalRowDescribe(r)
  1508. case 'n':
  1509. return rowsHeader{}
  1510. case 'E':
  1511. err := parseError(r)
  1512. cn.readReadyForQuery()
  1513. panic(err)
  1514. default:
  1515. cn.bad = true
  1516. errorf("unexpected Describe response %q", t)
  1517. }
  1518. panic("not reached")
  1519. }
  1520. func (cn *conn) readBindResponse() {
  1521. t, r := cn.recv1()
  1522. switch t {
  1523. case '2':
  1524. return
  1525. case 'E':
  1526. err := parseError(r)
  1527. cn.readReadyForQuery()
  1528. panic(err)
  1529. default:
  1530. cn.bad = true
  1531. errorf("unexpected Bind response %q", t)
  1532. }
  1533. }
  1534. func (cn *conn) postExecuteWorkaround() {
  1535. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1536. // any errors from rows.Next, which masks errors that happened during the
  1537. // execution of the query. To avoid the problem in common cases, we wait
  1538. // here for one more message from the database. If it's not an error the
  1539. // query will likely succeed (or perhaps has already, if it's a
  1540. // CommandComplete), so we push the message into the conn struct; recv1
  1541. // will return it as the next message for rows.Next or rows.Close.
  1542. // However, if it's an error, we wait until ReadyForQuery and then return
  1543. // the error to our caller.
  1544. for {
  1545. t, r := cn.recv1()
  1546. switch t {
  1547. case 'E':
  1548. err := parseError(r)
  1549. cn.readReadyForQuery()
  1550. panic(err)
  1551. case 'C', 'D', 'I':
  1552. // the query didn't fail, but we can't process this message
  1553. cn.saveMessage(t, r)
  1554. return
  1555. default:
  1556. cn.bad = true
  1557. errorf("unexpected message during extended query execution: %q", t)
  1558. }
  1559. }
  1560. }
  1561. // Only for Exec(), since we ignore the returned data
  1562. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1563. for {
  1564. t, r := cn.recv1()
  1565. switch t {
  1566. case 'C':
  1567. if err != nil {
  1568. cn.bad = true
  1569. errorf("unexpected CommandComplete after error %s", err)
  1570. }
  1571. res, commandTag = cn.parseComplete(r.string())
  1572. case 'Z':
  1573. cn.processReadyForQuery(r)
  1574. if res == nil && err == nil {
  1575. err = errUnexpectedReady
  1576. }
  1577. return res, commandTag, err
  1578. case 'E':
  1579. err = parseError(r)
  1580. case 'T', 'D', 'I':
  1581. if err != nil {
  1582. cn.bad = true
  1583. errorf("unexpected %q after error %s", t, err)
  1584. }
  1585. if t == 'I' {
  1586. res = emptyRows
  1587. }
  1588. // ignore any results
  1589. default:
  1590. cn.bad = true
  1591. errorf("unknown %s response: %q", protocolState, t)
  1592. }
  1593. }
  1594. }
  1595. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1596. n := r.int16()
  1597. colNames = make([]string, n)
  1598. colTyps = make([]fieldDesc, n)
  1599. for i := range colNames {
  1600. colNames[i] = r.string()
  1601. r.next(6)
  1602. colTyps[i].OID = r.oid()
  1603. colTyps[i].Len = r.int16()
  1604. colTyps[i].Mod = r.int32()
  1605. // format code not known when describing a statement; always 0
  1606. r.next(2)
  1607. }
  1608. return
  1609. }
  1610. func parsePortalRowDescribe(r *readBuf) rowsHeader {
  1611. n := r.int16()
  1612. colNames := make([]string, n)
  1613. colFmts := make([]format, n)
  1614. colTyps := make([]fieldDesc, n)
  1615. for i := range colNames {
  1616. colNames[i] = r.string()
  1617. r.next(6)
  1618. colTyps[i].OID = r.oid()
  1619. colTyps[i].Len = r.int16()
  1620. colTyps[i].Mod = r.int32()
  1621. colFmts[i] = format(r.int16())
  1622. }
  1623. return rowsHeader{
  1624. colNames: colNames,
  1625. colFmts: colFmts,
  1626. colTyps: colTyps,
  1627. }
  1628. }
  1629. // parseEnviron tries to mimic some of libkci's environment handling
  1630. //
  1631. // To ease testing, it does not directly reference os.Environ, but is
  1632. // designed to accept its output.
  1633. //
  1634. // Environment-set connection information is intended to have a higher
  1635. // precedence than a library default but lower than any explicitly
  1636. // passed information (such as in the URL or connection string).
  1637. func parseEnviron(env []string) (out map[string]string) {
  1638. out = make(map[string]string)
  1639. for _, v := range env {
  1640. parts := strings.SplitN(v, "=", 2)
  1641. accrue := func(keyname string) {
  1642. out[keyname] = parts[1]
  1643. }
  1644. unsupported := func() {
  1645. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1646. }
  1647. // The order of these is the same as is seen in the
  1648. // Kingbase 9.1 manual. Unsupported but well-defined
  1649. // keys cause a panic; these should be unset prior to
  1650. // execution. Options which kb expects to be set to a
  1651. // certain value are allowed, but must be set to that
  1652. // value if present (they can, of course, be absent).
  1653. switch parts[0] {
  1654. case "KBHOST":
  1655. accrue("host")
  1656. case "KBHOSTADDR":
  1657. unsupported()
  1658. case "KBPORT":
  1659. accrue("port")
  1660. case "KBDATABASE":
  1661. accrue("dbname")
  1662. case "KBUSER":
  1663. accrue("user")
  1664. case "KBPASSWORD":
  1665. accrue("password")
  1666. case "KBSERVICE", "KBSERVICEFILE", "KBREALM":
  1667. unsupported()
  1668. case "KBOPTIONS":
  1669. accrue("options")
  1670. case "KBAPPNAME":
  1671. accrue("application_name")
  1672. case "KBSSLMODE":
  1673. accrue("sslmode")
  1674. case "KBSSLCERT":
  1675. accrue("sslcert")
  1676. case "KBSSLKEY":
  1677. accrue("sslkey")
  1678. case "KBSSLROOTCERT":
  1679. accrue("sslrootcert")
  1680. case "KBREQUIRESSL", "KBSSLCRL":
  1681. unsupported()
  1682. case "KBREQUIREPEER":
  1683. unsupported()
  1684. case "KBKRBSRVNAME", "KBGSSLIB":
  1685. unsupported()
  1686. case "KBCONNECT_TIMEOUT":
  1687. accrue("connect_timeout")
  1688. case "KBCLIENTENCODING":
  1689. accrue("client_encoding")
  1690. case "KBDATESTYLE":
  1691. accrue("datestyle")
  1692. case "KBTZ":
  1693. accrue("timezone")
  1694. case "KBGEQO":
  1695. accrue("geqo")
  1696. case "KBSYSCONFDIR", "KBLOCALEDIR":
  1697. unsupported()
  1698. }
  1699. }
  1700. return out
  1701. }
  1702. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1703. func isUTF8(name string) bool {
  1704. // Recognize all sorts of silly things as "UTF-8", like Kingbase does
  1705. s := strings.Map(alnumLowerASCII, name)
  1706. return s == "utf8" || s == "unicode"
  1707. }
  1708. func alnumLowerASCII(ch rune) rune {
  1709. if 'A' <= ch && ch <= 'Z' {
  1710. return ch + ('a' - 'A')
  1711. }
  1712. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1713. return ch
  1714. }
  1715. return -1 // discard
  1716. }