# SPDX-FileCopyrightText: 2019 Guillermo Rodriguez
#
# SPDX-License-Identifier: AGPL-3.0-or-later

# -*- coding: utf-8 -*-

import os
import unittest

import intelmq.lib.test as test
import intelmq.lib.utils as utils
from intelmq.bots.parsers.shadowserver.parser import ShadowserverParserBot

# Raw CSV fixture shipped next to this test module. The encoding is given
# explicitly so that reading it does not depend on the platform's locale.
with open(os.path.join(os.path.dirname(__file__),
                       'testdata/test_smb.csv'), encoding='utf-8') as handle:
    EXAMPLE_FILE = handle.read()
EXAMPLE_LINES = EXAMPLE_FILE.splitlines()

# Input report fed to the parser: the whole CSV file base64-encoded in
# ``raw``, plus the original file name in ``extra.file_name``
# (NOTE(review): presumably used by the parser to select the feed - confirm).
EXAMPLE_REPORT = {"raw": utils.base64_encode(EXAMPLE_FILE),
                  "__type": "Report",
                  "time.observation": "2018-07-30T00:00:00+00:00",
                  "extra.file_name": "2019-01-01-test_smb-test-test.csv",
                  'feed.name': 'report feedname',
                  }
# Events expected from the parser, in emission order (index i is compared
# against the i-th output message).
EVENTS = [{
    '__type': 'Event',
    'feed.name': 'report feedname',
    "classification.identifier": 'test-smb',
    "classification.taxonomy": "vulnerable",
    "classification.type": "vulnerable-system",
    "extra.smb_implant": False,
    "extra.smb_major_number": '2',
    "extra.smb_minor_number": '1',
    "extra.smb_version_string": 'SMB 2.1',
    "extra.smbv1_support": 'N',
    "extra.tag": "smb",
    "protocol.application": "smb",
    "protocol.transport": "tcp",
    # First two fixture lines (presumably the CSV header plus the data row).
    'raw': utils.base64_encode('\n'.join([EXAMPLE_LINES[0],
                                         EXAMPLE_LINES[1]])),
    "source.asn": 64512,
    "source.geolocation.cc": "ZZ",
    "source.geolocation.city": "City",
    "source.geolocation.region": "Region",
    "source.ip": "192.168.0.1",
    "source.port": 445,
    "source.reverse_dns": "node01.example.com",
    "time.observation": "2018-07-30T00:00:00+00:00",
    "time.source": "2010-02-10T00:00:00+00:00"
},
]


class TestShadowserverParserBot(test.BotTestCase, unittest.TestCase):
    """
    A TestCase for a ShadowserverParserBot.

    Runs the ``test_smb`` CSV fixture through the parser with the
    ``test_mode`` parameter set and compares the output against ``EVENTS``.
    """

    @classmethod
    def set_bot(cls):
        """Register the bot class under test and its default input report."""
        cls.bot_reference = ShadowserverParserBot
        cls.default_input_message = EXAMPLE_REPORT

    def test_default(self):
        """ Test that the report's feed name is kept (not overwritten). """
        self.prepare_bot(parameters={'test_mode': True})
        # prepare=False, as in test_overwrite_feed_name: otherwise run_bot
        # prepares the bot again and the test_mode parameter set above is
        # not used for this run.
        self.run_bot(prepare=False)
        for i, expected in enumerate(EVENTS):
            self.assertMessageEqual(i, expected)

    def test_overwrite_feed_name(self):
        """ Test if feed name is overwritten if asked to do so. """
        self.prepare_bot(parameters={'test_mode': True, 'overwrite': True})
        self.run_bot(prepare=False)
        for i, expected in enumerate(EVENTS):
            # Copy so that the shared EVENTS fixture is not mutated.
            event = expected.copy()
            event['feed.name'] = 'Test-Accessible-SMB'
            self.assertMessageEqual(i, event)


# Allow running this test module directly as a script.
if __name__ == '__main__':  # pragma: no cover
    unittest.main()
