You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
4.6 KiB

  1. ###################################################################################
  2. #
  3. # Copyright (C) 2017 MuK IT GmbH
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU Affero General Public License as
  7. # published by the Free Software Foundation, either version 3 of the
  8. # License, or (at your option) any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU Affero General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU Affero General Public License
  16. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. #
  18. ###################################################################################
  19. import json
  20. import logging
  21. import psycopg2
  22. import functools
  23. from contextlib import closing
  24. from datetime import datetime, date
  25. from werkzeug.contrib.sessions import SessionStore
  26. from odoo.sql_db import db_connect
  27. from odoo.tools import config
  28. _logger = logging.getLogger(__name__)
  29. def ensure_cursor(func):
  30. @functools.wraps(func)
  31. def wrapper(self, *args, **kwargs):
  32. for attempts in range(1, 6):
  33. try:
  34. return func(self, *args, **kwargs)
  35. except psycopg2.InterfaceError as error:
  36. _logger.info("SessionStore connection failed! (%s/5)" % attempts)
  37. if attempts < 5:
  38. self._open_connection()
  39. else:
  40. raise error
  41. return wrapper
  42. class PostgresSessionStore(SessionStore):
  43. def __init__(self, *args, **kwargs):
  44. super(PostgresSessionStore, self).__init__(*args, **kwargs)
  45. self.dbname = config.get('session_store_dbname', 'session_store')
  46. self._open_connection()
  47. self._setup_db()
  48. def _create_database(self):
  49. with closing(db_connect("postgres").cursor()) as cursor:
  50. cursor.autocommit(True)
  51. cursor.execute("""
  52. CREATE DATABASE {dbname}
  53. ENCODING 'unicode'
  54. TEMPLATE 'template0';
  55. """.format(dbname=self.dbname))
  56. self._setup_db()
  57. def _open_connection(self, create_db=True):
  58. try:
  59. connection = db_connect(self.dbname, allow_uri=True)
  60. self.cursor = connection.cursor()
  61. self.cursor.autocommit(True)
  62. except:
  63. if not create_db:
  64. raise
  65. self._create_database()
  66. return self._open_connection(create_db=False)
  67. @ensure_cursor
  68. def _setup_db(self):
  69. self.cursor.execute("""
  70. CREATE TABLE IF NOT EXISTS sessions (
  71. sid varchar PRIMARY KEY,
  72. write_date timestamp without time zone NOT NULL,
  73. payload text NOT NULL
  74. );
  75. """)
  76. @ensure_cursor
  77. def save(self, session):
  78. self.cursor.execute("""
  79. INSERT INTO sessions (sid, write_date, payload)
  80. VALUES (%(sid)s, now() at time zone 'UTC', %(payload)s)
  81. ON CONFLICT (sid)
  82. DO UPDATE SET payload = %(payload)s, write_date = now() at time zone 'UTC';
  83. """, dict(sid=session.sid, payload=json.dumps(dict(session))))
  84. @ensure_cursor
  85. def delete(self, session):
  86. self.cursor.execute("DELETE FROM sessions WHERE sid=%s;", [session.sid])
  87. @ensure_cursor
  88. def get(self, sid):
  89. if not self.is_valid_key(sid):
  90. return self.new()
  91. self.cursor.execute("""
  92. SELECT payload, write_date
  93. FROM sessions WHERE sid=%s;
  94. """, [sid])
  95. try:
  96. payload, write_date = self.cursor.fetchone()
  97. if write_date.date() != datetime.today().date():
  98. self.cursor.execute("""
  99. UPDATE sessions
  100. SET write_date = now() at time zone 'UTC'
  101. WHERE sid=%s;
  102. """, [sid])
  103. return self.session_class(json.loads(payload), sid, False)
  104. except Exception:
  105. return self.session_class({}, sid, False)
  106. @ensure_cursor
  107. def list(self):
  108. self.cursor.execute("SELECT sid FROM sessions;")
  109. return [record[0] for record in self.cursor.fetchall()]
  110. @ensure_cursor
  111. def clean(self):
  112. self.cursor.execute("""
  113. DELETE FROM sessions
  114. WHERE now() at time zone 'UTC' - write_date > '7 days';
  115. """)