|
1 | 1 | """Test the base controller.""" |
2 | 2 |
|
| 3 | +import argparse |
3 | 4 | from unittest.mock import MagicMock, patch |
4 | 5 |
|
5 | 6 | import pytest |
@@ -77,3 +78,114 @@ def test_call_exit(): |
77 | 78 | with patch.object(controller, "save_class", MagicMock()): |
78 | 79 | controller.queue = ["quit"] |
79 | 80 | controller.call_exit(None) |
| 81 | + |
| 82 | + |
| 83 | +@pytest.fixture |
| 84 | +def mock_base_session(): |
| 85 | + """Mock the session for parse_known_args_and_warn tests.""" |
| 86 | + with patch("openbb_cli.controllers.base_controller.session") as mock_session: |
| 87 | + mock_session.settings.USE_CLEAR_AFTER_CMD = False |
| 88 | + yield mock_session |
| 89 | + |
| 90 | + |
| 91 | +def _make_parser(*args_spec): |
| 92 | + """Create an argparse parser from a list of add_argument kwargs.""" |
| 93 | + parser = argparse.ArgumentParser(add_help=False) |
| 94 | + for spec in args_spec: |
| 95 | + flags = spec.pop("flags") |
| 96 | + parser.add_argument(*flags, **spec) |
| 97 | + return parser |
| 98 | + |
| 99 | + |
| 100 | +def test_comma_split_flagged_value_not_split(mock_base_session): |
| 101 | + """Simple test: --symbol AAPL,MSFT must stay as one value.""" |
| 102 | + parser = _make_parser({"flags": ["--symbol", "-s"], "dest": "symbol", "type": str}) |
| 103 | + result = BaseController.parse_known_args_and_warn(parser, ["--symbol", "AAPL,MSFT"]) |
| 104 | + assert result is not None |
| 105 | + assert result.symbol == "AAPL,MSFT" |
| 106 | + |
| 107 | + |
| 108 | +def test_comma_split_short_flag_not_split(mock_base_session): |
| 109 | + """Short flag -s AAPL,MSFT must also stay as one value.""" |
| 110 | + parser = _make_parser({"flags": ["--symbol", "-s"], "dest": "symbol", "type": str}) |
| 111 | + result = BaseController.parse_known_args_and_warn(parser, ["-s", "AAPL,MSFT"]) |
| 112 | + assert result is not None |
| 113 | + assert result.symbol == "AAPL,MSFT" |
| 114 | + |
| 115 | + |
| 116 | +def test_comma_split_equals_syntax_not_split(mock_base_session): |
| 117 | + """--symbol=AAPL,MSFT must not be split.""" |
| 118 | + parser = _make_parser({"flags": ["--symbol", "-s"], "dest": "symbol", "type": str}) |
| 119 | + result = BaseController.parse_known_args_and_warn(parser, ["--symbol=AAPL,MSFT"]) |
| 120 | + assert result is not None |
| 121 | + assert result.symbol == "AAPL,MSFT" |
| 122 | + |
| 123 | + |
| 124 | +def test_comma_split_nargs_plus_all_values_protected(mock_base_session): |
| 125 | + """nargs='+': all consecutive values after --symbols are protected.""" |
| 126 | + parser = _make_parser( |
| 127 | + {"flags": ["--symbols"], "dest": "symbols", "nargs": "+", "type": str} |
| 128 | + ) |
| 129 | + result = BaseController.parse_known_args_and_warn( |
| 130 | + parser, ["--symbols", "AAPL,MSFT", "GOOG,AMZN"] |
| 131 | + ) |
| 132 | + assert result is not None |
| 133 | + assert result.symbols == ["AAPL,MSFT", "GOOG,AMZN"] |
| 134 | + |
| 135 | + |
| 136 | +def test_comma_split_nargs_star_values_protected(mock_base_session): |
| 137 | + """nargs='*': consecutive values after --tags are protected.""" |
| 138 | + parser = _make_parser( |
| 139 | + {"flags": ["--tags"], "dest": "tags", "nargs": "*", "type": str} |
| 140 | + ) |
| 141 | + result = BaseController.parse_known_args_and_warn(parser, ["--tags", "a,b", "c,d"]) |
| 142 | + assert result is not None |
| 143 | + assert result.tags == ["a,b", "c,d"] |
| 144 | + |
| 145 | + |
| 146 | +def test_comma_split_nargs_int_values_protected(mock_base_session): |
| 147 | + """nargs=2: both values after --pair are protected.""" |
| 148 | + parser = _make_parser( |
| 149 | + {"flags": ["--pair"], "dest": "pair", "nargs": 2, "type": str} |
| 150 | + ) |
| 151 | + result = BaseController.parse_known_args_and_warn(parser, ["--pair", "a,b", "c,d"]) |
| 152 | + assert result is not None |
| 153 | + assert result.pair == ["a,b", "c,d"] |
| 154 | + |
| 155 | + |
| 156 | +def test_comma_split_store_true_not_confused(mock_base_session): |
| 157 | + """store_true flags (nargs=0) should not protect the next token.""" |
| 158 | + parser = _make_parser( |
| 159 | + {"flags": ["--symbol", "-s"], "dest": "symbol", "type": str}, |
| 160 | + {"flags": ["--raw"], "dest": "raw", "action": "store_true", "default": False}, |
| 161 | + ) |
| 162 | + result = BaseController.parse_known_args_and_warn( |
| 163 | + parser, ["--raw", "--symbol", "AAPL,MSFT"] |
| 164 | + ) |
| 165 | + assert result is not None |
| 166 | + assert result.raw is True |
| 167 | + assert result.symbol == "AAPL,MSFT" |
| 168 | + |
| 169 | + |
| 170 | +def test_comma_split_no_comma_values_unchanged(mock_base_session): |
| 171 | + """Values without commas pass through unaffected.""" |
| 172 | + parser = _make_parser({"flags": ["--symbol", "-s"], "dest": "symbol", "type": str}) |
| 173 | + result = BaseController.parse_known_args_and_warn(parser, ["--symbol", "AAPL"]) |
| 174 | + assert result is not None |
| 175 | + assert result.symbol == "AAPL" |
| 176 | + |
| 177 | + |
| 178 | +def test_comma_split_multiple_flags_each_protected(mock_base_session): |
| 179 | + """Multiple flags each protect their own values independently.""" |
| 180 | + parser = _make_parser( |
| 181 | + {"flags": ["--symbol", "-s"], "dest": "symbol", "type": str}, |
| 182 | + {"flags": ["--raw"], "dest": "raw", "action": "store_true", "default": False}, |
| 183 | + {"flags": ["--provider"], "dest": "provider", "type": str}, |
| 184 | + ) |
| 185 | + result = BaseController.parse_known_args_and_warn( |
| 186 | + parser, |
| 187 | + ["--symbol", "AAPL,MSFT", "--provider", "yfinance,polygon"], |
| 188 | + ) |
| 189 | + assert result is not None |
| 190 | + assert result.symbol == "AAPL,MSFT" |
| 191 | + assert result.provider == "yfinance,polygon" |
0 commit comments