Skip to content

Commit d9dceac

Browse files
Fix Oracle Provider different count query and feature query when a sql_manipulator is used, Issue #1831 (#1909)
* Adjusted Query Function to include SQL Manipulation in the Count Query for numberMatched retrieval. Now there queries should be the same in any case. Added a function for less repetitions and added a Test. * Made process_query_with_sql_manipulator_sup more concise and removed duplications
1 parent 51c6a95 commit d9dceac

2 files changed

Lines changed: 74 additions & 32 deletions

File tree

pygeoapi/provider/oracle.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,38 @@ def _get_srid_from_crs(self, crs):
613613

614614
return srid
615615

616+
def _process_query_with_sql_manipulator_sup(
617+
self, db, sql_query, bind_variables, extra_params, **query_args
618+
):
619+
"""
620+
Apply the SQL manipulation plugin to process the SQL query.
621+
622+
:param db: Database connection instance
623+
:param sql_query: The SQL query to process
624+
:param bind_variables: Query bind variables
625+
:param extra_params: Additional parameters for manipulation
626+
:param query_args: Other dynamic arguments required for processing
627+
:return: Processed SQL query and bind variables
628+
"""
629+
if self.sql_manipulator:
630+
LOGGER.debug(f"sql_manipulator: {self.sql_manipulator}")
631+
manipulation_class = _class_factory(self.sql_manipulator)
632+
633+
# Pass all arguments to the process_query method
634+
sql_query, bind_variables = manipulation_class.process_query(
635+
db=db,
636+
sql_query=sql_query,
637+
bind_variables=bind_variables,
638+
sql_manipulator_options=self.sql_manipulator_options,
639+
**query_args,
640+
extra_params=extra_params,
641+
)
642+
643+
for placeholder in ["#HINTS#", "#JOIN#", "#WHERE#"]:
644+
sql_query = sql_query.replace(placeholder, "")
645+
646+
return sql_query, bind_variables
647+
616648
def query(
617649
self,
618650
offset=0,
@@ -695,9 +727,36 @@ def query(
695727
# because of getFields ...
696728
sql_query = f"SELECT COUNT(1) AS hits \
697729
FROM {self.table} \
698-
{where_dict['clause']}"
730+
{where_dict['clause']} #WHERE#"
731+
732+
# Assign where_dict["properties"] to bind_variables
733+
bind_variables = {**where_dict["properties"]}
734+
735+
# Default values for the process_query function (sql_manipulator)
736+
query_args = {
737+
"offset": offset,
738+
"limit": limit,
739+
"resulttype": resulttype,
740+
"bbox": bbox,
741+
"datetime_": datetime_,
742+
"properties": properties,
743+
"sortby": sortby,
744+
"skip_geometry": skip_geometry,
745+
"select_properties": select_properties,
746+
"crs_transform_spec": crs_transform_spec,
747+
"q": q,
748+
"language": language,
749+
"filterq": filterq,
750+
}
751+
752+
# Apply the SQL manipulation plugin
753+
extra_params["geom"] = self.geom
754+
sql_query, bind_variables = self._process_query_with_sql_manipulator_sup( # noqa: E501
755+
db, sql_query, bind_variables, extra_params, **query_args
756+
)
757+
699758
try:
700-
cursor.execute(sql_query, where_dict["properties"])
759+
cursor.execute(sql_query, bind_variables)
701760
except oracledb.Error as err:
702761
LOGGER.error(
703762
f"Error executing sql_query: {sql_query}: {err}"
@@ -795,36 +854,10 @@ def query(
795854
# Create dictionary for sql bind variables
796855
bind_variables = {**where_dict["properties"], **paging_bind}
797856

798-
# SQL manipulation plugin
799-
if self.sql_manipulator:
800-
LOGGER.debug("sql_manipulator: " + self.sql_manipulator)
801-
manipulation_class = _class_factory(self.sql_manipulator)
802-
sql_query, bind_variables = manipulation_class.process_query(
803-
db,
804-
sql_query,
805-
bind_variables,
806-
self.sql_manipulator_options,
807-
offset,
808-
limit,
809-
resulttype,
810-
bbox,
811-
datetime_,
812-
properties,
813-
sortby,
814-
skip_geometry,
815-
select_properties,
816-
crs_transform_spec,
817-
q,
818-
language,
819-
filterq,
820-
extra_params=extra_params
821-
)
822-
823-
# Clean up placeholders that aren't used by the
824-
# manipulation class.
825-
sql_query = sql_query.replace("#HINTS#", "")
826-
sql_query = sql_query.replace("#JOIN#", "")
827-
sql_query = sql_query.replace("#WHERE#", "")
857+
# Apply the SQL manipulation plugin
858+
sql_query, bind_variables = self._process_query_with_sql_manipulator_sup( # noqa: E501
859+
db, sql_query, bind_variables, extra_params, **query_args
860+
)
828861

829862
LOGGER.debug(f"SQL Query: {sql_query}")
830863
LOGGER.debug(f"Bind variables: {bind_variables}")

tests/test_oracle_provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def process_query(
7070

7171
if sql_query.find(" WHERE ") == -1:
7272
sql_query = sql_query.replace("#WHERE#", f" WHERE {sql}")
73+
7374
else:
7475
sql_query = sql_query.replace("#WHERE#", f" AND {sql}")
7576

@@ -644,6 +645,14 @@ def test_extra_params_are_passed_to_sql_manipulator(config_manipulator):
644645
assert not response['features']
645646

646647

648+
def test_query_count_sql_manipulator(config_manipulator):
649+
"""Test query number of hits"""
650+
p = OracleProvider(config_manipulator)
651+
result = p.query(resulttype="hits")
652+
653+
assert result.get("numberMatched") == 1
654+
655+
647656
@pytest.fixture()
648657
def database_connection_pool(config_db_conn):
649658
os.environ["ORACLE_POOL_MIN"] = "2" # noqa: F841

0 commit comments

Comments
 (0)