You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

129 lines
4.5 KiB

  1. ###################################################################################
  2. #
  3. # Copyright (C) 2017 MuK IT GmbH
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU Affero General Public License as
  7. # published by the Free Software Foundation, either version 3 of the
  8. # License, or (at your option) any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU Affero General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU Affero General Public License
  16. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. #
  18. ###################################################################################
import functools
import json
import logging
from contextlib import closing
from datetime import datetime

import psycopg2
from werkzeug.contrib.sessions import SessionStore

from odoo.sql_db import db_connect
from odoo.tools import config
  27. _logger = logging.getLogger(__name__)
  28. def ensure_cursor(func):
  29. @functools.wraps(func)
  30. def wrapper(self, *args, **kwargs):
  31. for attempts in range(1, 6):
  32. try:
  33. return func(self, *args, **kwargs)
  34. except psycopg2.InterfaceError as error:
  35. _logger.info("SessionStore connection failed! (%s/5)" % attempts)
  36. if attempts < 5:
  37. self._open_connection()
  38. else:
  39. raise error
  40. return wrapper
  41. class PostgresSessionStore(SessionStore):
  42. def __init__(self, *args, **kwargs):
  43. super(PostgresSessionStore, self).__init__(*args, **kwargs)
  44. self.dbname = config.get('session_store_dbname', 'session_store')
  45. self._open_connection()
  46. self._setup_db()
  47. def _create_database(self):
  48. with closing(db_connect("postgres").cursor()) as cursor:
  49. cursor.autocommit(True)
  50. cursor.execute("""
  51. CREATE DATABASE {dbname}
  52. ENCODING 'unicode'
  53. TEMPLATE 'template0';
  54. """.format(dbname=self.dbname))
  55. self._setup_db()
  56. def _open_connection(self, create_db=True):
  57. try:
  58. connection = db_connect(self.dbname, allow_uri=True)
  59. self.cursor = connection.cursor()
  60. self.cursor.autocommit(True)
  61. except:
  62. if not create_db:
  63. raise
  64. self._create_database()
  65. return self._open_connection(create_db=False)
  66. @ensure_cursor
  67. def _setup_db(self):
  68. self.cursor.execute("""
  69. CREATE TABLE IF NOT EXISTS sessions (
  70. sid varchar PRIMARY KEY,
  71. write_date timestamp without time zone NOT NULL,
  72. payload text NOT NULL
  73. );
  74. """)
  75. @ensure_cursor
  76. def save(self, session):
  77. self.cursor.execute("""
  78. INSERT INTO sessions (sid, write_date, payload)
  79. VALUES (%(sid)s, now() at time zone 'UTC', %(payload)s)
  80. ON CONFLICT (sid)
  81. DO UPDATE SET payload = %(payload)s, write_date = now() at time zone 'UTC';
  82. """, dict(sid=session.sid, payload=json.dumps(dict(session))))
  83. @ensure_cursor
  84. def delete(self, session):
  85. self.cursor.execute("DELETE FROM sessions WHERE sid=%s;", [session.sid])
  86. @ensure_cursor
  87. def get(self, sid):
  88. if not self.is_valid_key(sid):
  89. return self.new()
  90. self.cursor.execute("""
  91. SELECT payload, write_date
  92. FROM sessions WHERE sid=%s;
  93. """, [sid])
  94. try:
  95. payload, write_date = self.cursor.fetchone()
  96. if write_date.date() != datetime.today().date():
  97. self.cursor.execute("""
  98. UPDATE sessions
  99. SET write_date = now() at time zone 'UTC'
  100. WHERE sid=%s;
  101. """, [sid])
  102. return self.session_class(json.loads(payload), sid, False)
  103. except Exception:
  104. return self.session_class({}, sid, False)
  105. @ensure_cursor
  106. def list(self):
  107. self.cursor.execute("SELECT sid FROM sessions;")
  108. return [record[0] for record in self.cursor.fetchall()]
  109. @ensure_cursor
  110. def clean(self):
  111. self.cursor.execute("""
  112. DELETE FROM sessions
  113. WHERE now() at time zone 'UTC' - write_date > '7 days';
  114. """)