Как стать автором
Обновить

Комментарии 31

рефакторинг должен быть усилен юнит/интеграционными тестами, а это значит мы должны полностью представлять логику работы функции. Без тестов по типу сожри вот эту фиг пойми какую бяку и дай мне нормальный текст даст результат убедительного текста похожего на код, но это не равно правильный код.

Кстати, один из нормальных кейсов для GPT - "напиши мне тесты по этой функции". Смотрим глазами; если все ок, добавляем в тестсьют.

Как по мне, самое лучшее у GPT- "напиши мне функцию, делает то то и то то". Функций написали, потом написали тестов или вручную (что все равно быстрее), затем пишем условный монолит "Есть задача.....есть функции. По сути промт как над-язык основного ЯП. И тут тоже можно двигаться по функциям, добавляя их в программу поочередно и тестируя выход.

Что касается функций, я нашел некий метод, заставлющую ИИ вылизывать функцию. Делается это на любом языке автоматизации, я делал вообще на VistaRunner. Суть скрипта проста как топор : после появления оранжевой кнопки "Generate" скрипт считает, что ИИ завершил работу и ждет новых указаний. Именно в этот момент скрипт шлет в текстовое поле чата рандомную строку из текстового файла и жмет ввод. И ждет дальше доступность кнопки generate. Содержимое текстового файла: Тут ошибка исправь (повторить 10 раз) затем уточнения по промту, часть промта повторяем сюда, например сортировка по турнирному алгоритму. Неважно, что напишет моделька, скрипт слепой. Но после порядка 20 таких итераций на выходе получаем или дистиллированный код, или запросы модельки "да что тебе от меня надо? Я уже и так и сяк." Для этого в том же текстовом файле иногда вставляем рандомную критику, или похвалу "я заметил, что память тут расходуется слишком сильно" или наоборот - "хороший код, но я бы назвал переменные более осмысленно" - льем общие слова, моделька пахает.

Психология)

Вот из-за таких вот затейников взбесившийся ИИ и захочет уничтожить человечество.

Ии не человек, ему не больно, не скучно, и не страшно, нет никаких оснований думать что ему захочется кого то уничтожить.

Делают ли "электровцы" "саморефакторинг" своей нейросетки? )

Дешевле нанять одного программиста из Индии... )

Это не работает от «совсем». Без контекста тест будет синтетическим ради теста, а описывать контекст в промпте зачастую на порядки сложнее, чем написать юниты самостоятельно.

Пока контекстное окно моделей будет у трепанированной золотой рыбки, вся эта херь с убийством программистов работать не будет. вообще. точка.

Специализированными моделями не пользовался, пробовал для кодинга ChatGPT, и для задач типа "напиши код, который делает то-то" примерно в 30% случаях получал рабочий код. Примерно треть случаев - код не работал сразу, как надо, но логика была понятна, и после ручного рефакторинга использовать было можно (хотя, ценность, конечно, резко падает). В остальных случаях нейросеть фигню какую-то выдавала.

Но попробовал ставить задачу под конкретный фреймворк - Laravel, и тут дела гораздо лучше пошли. Описал сущности, их свойства, и получил код миграций, моделей и контроллера. Потом попросил дать код blade-шаблонов для CRUD и получил их. Прикольно, что ГПТ запоминает контекст, и когда попросил добавить доп.поле, то получил в ответ и миграцию, и обновленные модель с контроллером.

В итоге получил небольшое работающее приложение, где я не написал ни строчки кода, был сплошной копипаст, хотя, и не бездумный - местами копипаст был фрагментарный.

А вот GigaChat от Сбера тупеньким оказался, он сильно старался, но код был совсем нерабочий.

А вот интересно, можно ли делать TDD с ИИ? Пишешь тесты, а ИИ пусть пишет к ним код и проверяет его на тестах.

Там в тексте есть характерный пример: когда-то эту деталь токарь руками точил, а сейчас станок-автомат её делает на раз сотнями за смену и с лучшим качеством, нужно только работнику подучиться малость 😂. Типа, это хорошо и это даже не подвергается сомнению. Да, это несомненно, если таких деталей нужны сотни за смену и неохота платить зарплату квалифицированному токарю седьмого разряда... Тогда и станок дорогущий можно купить, и специалистам по обслуживанию этого станка платить... ИИ - это не интеллект. Это торговая марка. Здесь можно трактаты на эту тему писать - но неохота, ибо лень. А "искину" не лень, поэтому скоро мы все будем завалены информационным мусором, порождаемым "by AI", по самое горло. Если не спохватимся, конечно.

Но в реальности станок, который выдает преемлимый результат хорошо если один раз из двух просто выбросили бы на помойку.

Ни одна маленькая модель размером 8B-13B, например, Code-Llama-13B или Nxcode-CQ-7B-orpo (она же усовершенствованная CodeQwen1.5-7B) или deepseek-coder-6.7b ни разу у меня не выдала удовлетворительного результата в более-менее серьезных задачах по кодированию. Этих размеров явно недостаточно, лучше даже не тратьте своё время и силы. Codestral хотя бы имеет 22B весов и это минимум для того, чтобы иметь практическую ценность. Следовательно, специально обученные для программирования модели большего размера будет вероятно еще более продвинутыми и полезными, с весами 34B или 70B или более и тут уж всё зависит только от вашего железа, если вы работаете локально.

А можно вот это в понятные числа перевести? Сколько памяти нужно, какая видеокарта и процессор?

Вторая причина неудач - это сам запрос, промпт (точнее их последовательность), который должен быть составлен грамотно и профессионально.

Тут может оказаться, что быстрее написать сам код, чем последовательность промптов с анализом результатов.

На сайте ollama можно посмотреть примерно сколько надо для запуска моделей на своем компьютере. https://ollama.com/

С реальными задачами всё очень сложно. Банально просим сделать коллаж из фоток, пишем что надо укладывать по столько то картинок но чтоб не шире 2000 получилось. ИИ делает, смотришь и понимаешь что даже из 2 картинок может получится шире 2000. Или даже одной. И что тогда? Говоришь что в таком случае вот так надо, а в таком вот эдак, получается уже что то настолько сложное что даже просто понять что происходит трудно. Куча нюансов выплывает на пустом месте которые ты не ожидал вообще, а уж ИИ и подавно.

Видишь такое и понимаешь что не понимаешь даже как это проконтролировать, а ведь так просто начиналось.

Я честно перешел по вашей ссылке и так и не понял, какое железо мне купить. Того, что у вас на скрине, у меня на десктопе нет 0_о А даже если бы и были, это гигабайты на винте, в ОЗУ или видеопамяти? Чего покупать-то?

Это размер на диске. Скорее всего эти файлы тупо отображаются в память. В видеопамять если есть, в обычную если нет. Что бы гонять 40гбайтную модель с приличной скоростью надо 2 видеокарты по 24гб памяти у каждой Ж)

А можно вот это в понятные числа перевести? Сколько памяти нужно

В названии модели обычно указано количество весов, например, 70B обозначает 70 миллиардов (B - billion). Каждый вес храниться в виде числа, которое в зависимости от квантизации модели занимает определенное число бит. Например, квантизация fp16 хранит веса в виде чисел с плавающей запятой половиной точности, таким образом на один вес приходится 16 бит данных или 2 байта, следовательно вся модель будет весить 140 миллиардов байт или 140Гб. Для работы такой модели потребуется 140Гб оперативной памяти и ещё сколько-то для хранения контекста. Если понизить точность весов, можно существенно уменьшить размер модели, при этом до какого-то предела качество её ответов будет страдать не так сильно. Это собственно и есть процесс квантизации. Самому это делать не нужно, модели в нужной квантизации как правило можно скачать. Так при квантизации 4 бита на вес модель уже будет занимать 35Гб (+память под контекст), что конечно все равно много для потребительских видеокарт, но можно погонять на CPU (медленно, ~1 токен/с). Таким образом для видеокарт с 12-16Гб VRAM выбор модели и квантизации это все равно компромисс, а в случае с запуском на CPU скорость работы оставляет желать лучшего.

Спасибо, стало понятнее. Я просто где-то читал, что важна и видеопамять и обычная. Наверное, что-то не так понял.

Однозначно второе нет, суть в том, что просты условно типовые, а код каждый раз уникальный. Да и текстом большинству людей легче общаться. Данная статья направлена на то ,что ллм могут заменить низкоуровневых специалистов, но на серьезные вещи пока не способна. Безусловно столь детерминированная область как код будет замещена машиной, вопрос времени, надеюсь тут никто спорить не будет. Вопрос времени.

Безусловно столь детерминированная область как код будет замещена машиной, вопрос времени, надеюсь тут никто спорить не будет. Вопрос времени.

Код, на основании четко посталенного и разжёванного ТЗ может писать чуть ли не секретарша, она быстро текст набирает. Программирование != написание кода. А что касается вопроса времени, "до промышленного термояда осталось 30 лет" (с)

Именно, что нет, писать код не может просто образованный или начитанный человек. Если очень подробно и разжевать задание, но без навыков и знания синтаксиса, даже биолог или маркетолог не сможет справиться с задачей, тем более секретарша. Поэтому ллм обученные на кодовой базе, гиты, парадигмах и книгах , stack overflow, чаты разработчиков и т.д смогут это сделать. Автор корректно указал про объём параметров от 70В, phind решает огромное количество задач. Попробуйте, вы будете сильно удивлены, что необходимо прокачивать доп.навыки, чтобы не стать кожаным мешком)

А можно вот это в понятные числа перевести? Сколько памяти нужно, какая видеокарта и процессор?

У меня модель Codestral 22B с квантизацией 4 бита локально работает на машине с 32 гб RAM и 12 гб VRAM. Процессор Ryzen 5 5600X.

Спасибо, хоть какие-то вменяемые числа.

Не программист. Использовал, использую и буду использовать ии для маленьких задачек типа скриптов для парсинга и какого-то примитивного обсчета, а также для всякой мелочи на fast api.

По наличии знаний о том как поставить задачу гпт, а также некоторого хорошего понимания, что мне нужно и что в коде реализовано криво, зачем же мне самому руками это с нуля пилить?

Резюме: для моих вспомогательных задач гпт крайне хороший помощник. Используют ли его настоящие разработчики для написания программ? Скорее всего :-)

Отличное применение для упрощения жизни и экономии времени!)

Вот она, нейронка моей мечты (с) один счастливый программист

- try:
-     s1 = datetime.strptime(s.replace('24:00','00:00'), '%d/%m/%y %H:%M')
-     if mg:
-         s2 = datetime.strptime(rows2[ind2], '%d/%m/%y %H:%M')
-         if ind2 < len(rows2)-1 and s1 > s2:
-             s = f"Date={s} of {src} exceeded the date={rows2[ind2]} из {src2}"
-             show_text(s,1)
-             exit(0)
- except Exception as err:
-     s = f'Date format error {s}'
-     show_text(s)
-     print(s + " in " + src)
-     exit(0)


+ current_time = parse_time(current_time_str)
+
+ if is_market_data:
+     next_time = datetime.strptime(rows2[index2], '%d/%m/%y %H:%M')
+     if index2 < len(rows2)-1 and current_time > next_time:
+         error_message = f"Date={current_time_str} of {source_file} exceeded the date={rows2[index2]} из {source_file2}"
+         show_text(error_message, 1)
+         exit(0)

Любопытно, как изменилась функция после так называемого рефакторинга. В одном месте прибыло и стало лучше, в другом, наоборот, стало хуже. Я уверен, будь у вас набор тестов для этой функции, вы бы сами увидели, насколько качественным является этот рефакторинг.

Я бы на вашем месте отправил эту нейронку на пересдачу, а не твёрдую 4+ ставил

Вот ревью вашего кода от другой LLM 😂

Рекомендации по улучшению:

  1. Добавьте аннотации типов для улучшения читаемости и облегчения отладки.

  2. Рассмотрите возможность использования dataclass для структурирования данных свечей.

  3. Вместо exit(0) лучше возвращать статус ошибки, позволяя вызывающему коду решать, как обрабатывать ошибки.

  4. Рассмотрите возможность уменьшения зависимости от глобальных переменных, передавая необходимые параметры в функции.

  5. Добавьте документацию (docstrings) к функциям для лучшего понимания их назначения и параметров

  6. (Лично от меня) добавьте типтзацию

И вариант кода

from dataclasses import dataclass
from datetime import datetime
import numpy as np

@dataclass
class CandleData:
    timestamp: datetime
    open_price: float
    high_price: float
    low_price: float
    close_price: float
    volume: float
    is_manual: bool

def load_market_data(data_file_path, start_time, learning_period, history_length):
    manual_data = load_manual_data() if has_manual_data() else []
    market_data = load_and_parse_data(data_file_path, manual_data)
    filtered_data = filter_data(market_data, start_time, learning_period)
    return format_output(filtered_data)

def has_manual_data():
    return 5 in global_context.ohlcv_columns

def load_manual_data():
    manual_file = f"{global_context.ticker}/{global_context.timeframe}/baseline/manual_{global_context.baseline}{global_context.baseline_number}.txt"
    return read_file(manual_file)

def read_file(file_path):
    try:
        with open(file_path, 'r') as file:
            return file.readlines()
    except IOError as e:
        print(f"Error reading file {file_path}: {str(e)}")
        return []

def load_and_parse_data(data_file_path, manual_data):
    data = []
    manual_index = 0
    for line in read_file(data_file_path):
        candle = parse_candle(line, manual_data[manual_index] if manual_index < len(manual_data) else None)
        if candle:
            data.append(candle)
            if candle.is_manual and manual_index < len(manual_data) - 1:
                manual_index += 1
    return data

def parse_candle(line, manual_timestamp):
    parts = line.strip().split(',')
    if len(parts) != 7:
        print(f"Invalid data format: {line}")
        return None

    timestamp = parse_timestamp(f"{parts[0]} {parts[1]}")
    if not timestamp:
        return None

    price_mult = 0.01 if global_context.ticker == 'GMK' and timestamp <= datetime(2021, 3, 24) else 1.0

    try:
        return CandleData(
            timestamp=timestamp,
            open_price=float(parts[2]) * price_mult,
            high_price=float(parts[3]) * price_mult,
            low_price=float(parts[4]) * price_mult,
            close_price=float(parts[5]) * price_mult,
            volume=float(parts[6]),
            is_manual=manual_timestamp and manual_timestamp.strip() == f"{parts[0]} {parts[1]}"
        )
    except ValueError:
        print(f"Invalid numeric data: {line}")
        return None

def parse_timestamp(timestamp_str):
    try:
        return datetime.strptime(timestamp_str.replace('24:00', '00:00'), '%d/%m/%y %H:%M')
    except ValueError:
        print(f"Invalid timestamp format: {timestamp_str}")
        return None

def filter_data(data, start_time, learning_period):
    start = datetime.strptime(start_time.replace('24:00', '00:00'), '%d/%m/%y %H:%M')
    return [candle for candle in data if candle.timestamp >= start][:learning_period]

def format_output(data):
    timestamps = [candle.timestamp.strftime('%d/%m/%y %H:%M') for candle in data]
    close_prices = np.array([candle.close_price for candle in data])
    full_data = np.array([[candle.open_price, candle.high_price, candle.low_price, 
                           candle.close

Или

from dataclasses import dataclass
from datetime import datetime
from typing import List, Tuple, Optional, Union
import numpy as np

@dataclass
class CandleData:
    timestamp: datetime
    open_price: float
    high_price: float
    low_price: float
    close_price: float
    volume: float
    is_manual: bool

def load_market_data(data_file_path: str, start_time: str, learning_period: int, history_length: int) -> Tuple[bool, Union[Tuple[List[str], np.ndarray, np.ndarray], str]]:
    """
    Загружает и обрабатывает рыночные данные из файла.
    """
    global_context.current_index = 0

    manual_timestamps = load_manual_data() if has_manual_data() else []
    success, result = read_file(data_file_path)
    if not success:
        return False, result

    market_data_rows = result
    success, result = process_market_data(market_data_rows, manual_timestamps)
    if not success:
        return False, result

    market_data = result
    success, result = filter_and_limit_data(market_data, start_time, learning_period)
    if not success:
        return False, result

    analyzed_data = result
    return format_output(analyzed_data)

def has_manual_data() -> bool:
    """Проверяет, есть ли ручные данные."""
    return 5 in global_context.ohlcv_columns

def load_manual_data() -> List[str]:
    """Загружает ручные данные из файла."""
    manual_data_path = f"{global_context.ticker}/{global_context.timeframe}/baseline/manual_{global_context.baseline}{global_context.baseline_number}.txt"
    success, result = read_file(manual_data_path)
    return result if success else []

def read_file(file_path: str) -> Tuple[bool, Union[List[str], str]]:
    """Читает данные из файла."""
    try:
        with open(file_path, 'r') as file:
            return True, file.readlines()
    except IOError as e:
        return False, f"Error reading file {file_path}: {str(e)}"

def process_market_data(market_data_rows: List[str], manual_timestamps: List[str]) -> Tuple[bool, Union[List[CandleData], str]]:
    """Обрабатывает строки рыночных данных."""
    market_data = []
    manual_data_index = 0

    for row in market_data_rows:
        current_manual_timestamp = get_manual_timestamp(manual_timestamps, manual_data_index)
        success, result = parse_market_data(row, current_manual_timestamp)
        if not success:
            return False, result
        
        candle_data = result
        market_data.append(candle_data)
        
        if candle_data.is_manual and manual_data_index < len(manual_timestamps) - 1:
            manual_data_index += 1

    return True, market_data

def get_manual_timestamp(manual_timestamps: List[str], index: int) -> Optional[str]:
    """Получает ручную метку времени, если она доступна."""
    return manual_timestamps[index].strip() if index < len(manual_timestamps) else None

def parse_market_data(row: str, manual_timestamp: Optional[str] = None) -> Tuple[bool, Union[CandleData, str]]:
    """Разбирает строку рыночных данных."""
    data_elements = row.strip().split(",")
    if len(data_elements) != 7:
        return False, f"Invalid data format in row: {row}"

    timestamp_str = f"{data_elements[0]} {data_elements[1]}"
    timestamp = parse_timestamp(timestamp_str
                                тут что-то потеряно

А вот ещё

Пожалуй, лучший вариант

Хотя, я Шарп больше люблю. Да и исключения не люблю.

from dataclasses import dataclass
from datetime import datetime
from typing import List, Tuple, Optional
import csv

@dataclass
class DataPoint:
    date: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    is_manual: bool

class DataLoader:
    TICKER = "GMK"
    OHLCV = [1, 2, 3, 4, 5]
    TFRAME = "1H"
    BSL = "1"
    BS_N = "1"

    @staticmethod
    def load_data(source_path: str, start_time: str, learning_period: int, normalization_factor: int) -> Tuple[List[str], List[float], List[DataPoint]]:
        try:
            rows = DataLoader.read_file(source_path)
            manual_rows = DataLoader.read_file(DataLoader.get_manual_source_path()) if DataLoader.has_manual_data() else []
            data_points = DataLoader.parse_data_points(rows, manual_rows)
            relevant_data = DataLoader.filter_relevant_data(data_points, start_time, learning_period)
            return DataLoader.format_output(relevant_data)
        except Exception as e:
            print(f"Error: {str(e)}")
            return [], [], []

    @staticmethod
    def has_manual_data() -> bool:
        return 5 in DataLoader.OHLCV

    @staticmethod
    def get_manual_source_path() -> str:
        return f"{DataLoader.TICKER}/{DataLoader.TFRAME}/bsl/manual_{DataLoader.BSL}{DataLoader.BS_N}.txt"

    @staticmethod
    def read_file(path: str) -> List[List[str]]:
        with open(path, 'r') as file:
            return list(csv.reader(file))

    @staticmethod
    def parse_data_points(rows: List[List[str]], manual_rows: List[List[str]]) -> List[DataPoint]:
        data_points = []
        for row, manual_row in zip(rows, manual_rows + [None] * (len(rows) - len(manual_rows))):
            data_point = DataLoader.parse_row(row, manual_row)
            if data_point:
                data_points.append(data_point)
        return data_points

    @staticmethod
    def parse_row(row: List[str], manual_row: Optional[List[str]]) -> Optional[DataPoint]:
        if len(row) != 7:
            print(f"Invalid row format: {row}")
            return None

        try:
            date_time = DataLoader.parse_datetime(row[0], row[1])
            coefficient = DataLoader.calculate_coefficient(date_time)
            values = [float(v) for v in row[2:7]]
            
            return DataPoint(
                date=date_time,
                open=values[0] * coefficient,
                high=values[1] * coefficient,
                low=values[2] * coefficient,
                close=values[3] * coefficient,
                volume=values[4],
                is_manual=DataLoader.is_manual_data_point(date_time, manual_row)
            )
        except ValueError as e:
            print(f"Error parsing row {row}: {str(e)}")
            return None

    @staticmethod
    def parse_datetime(date: str, time: str) -> datetime:
        return datetime.strptime(f"{date} {time}", "%d/%m/%y %H:%M")

    @staticmethod
    def calculate_coefficient(date_time: datetime) -> float:
        return 0.01 if DataLoader.TICKER == "GMK" and date_time <= datetime(2021, 3, 24) else 1.0

    @staticmethod
    def is_manual_data_point(date_time: datetime, manual_row: Optional[List[str]]) -> bool:
        return manual_row is not None and manual_row[0] == date_time.strftime("%d/%m/%y %H:%M")

    @staticmethod
    def filter_relevant_data(data_points: List[DataPoint], start_time: str, learning_period: int) -> List[DataPoint]:
        start_datetime = DataLoader.parse_datetime(start_time[:8], start_time[9:14])
        relevant_data = [dp for dp in data_points if dp.date >= start_datetime][:learning_period]
        
        if len(relevant_data) < learning_period:
            print(f"Warning: Not enough data points after {start_time}")
        
        return relevant_data

    @staticmethod
    def format_output(data_points: List[DataPoint]) -> Tuple[List[str], List[float], List[DataPoint]]:
        return (
            [dp.date.strftime("%d/%m/%y %H:%M") for dp in data_points],
            [dp.close for dp in data_points],
            data_points
        )

# Пример использования
if __name__ == "__main__":
    dates, closes, data_points = DataLoader.load_data("data.csv", "01/01/21 00:00", 1000, 0)
    if dates and closes and data_points:
        print(f"Loaded {len(dates)} data points")
        print(f"First date: {dates[0]}, Last date: {dates[-1]}")
        print(f"First close: {closes[0]}, Last close: {closes[-1]}")
    else:
        print("Failed to load data")

# Тесты
import unittest
import tempfile
import os

class TestDataLoader(unittest.TestCase):
    def test_load_data_valid_input(self):
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file:
            temp_file.write("01/01/21,00:00,100,101,99,100.5,1000\n"
                            "01/01/21,01:00,100.5,102,100,101.5,1500\n")
            temp_file_path = temp_file.name

        dates, closes, data_points = DataLoader.load_data(temp_file_path, "01/01/21 00:00", 2, 0)
        
        self.assertEqual(len(dates), 2)
        self.assertEqual(dates[0], "01/01/21 00:00")
        self.assertEqual(dates[1], "01/01/21 01:00")
        self.assertEqual(closes[0], 100.5)
        self.assertEqual(closes[1], 101.5)

        os.unlink(temp_file_path)

    def test_load_data_invalid_file(self):
        dates, closes, data_points = DataLoader.load_data("non_existent_file.csv", "01/01/21 00:00", 1000, 0)
        self.assertEqual(len(dates), 0)
        self.assertEqual(len(closes), 0)
        self.assertEqual(len(data_points), 0)

if __name__ == '__main__':
    unittest.main()
```

Эта версия кода:

1. Использует более традиционный для Python подход с обработкой исключений.
2. Упрощает структуру кода, делая его более линейным и легким для чтения.
3. Использует опциональные типы (Optional) для обработки отсутствующих данных.
4. Выводит предупреждения и ошибки в консоль вместо использования монады Result.
5. Сохраняет основную функциональность, но делает код более понятным для Python-раз

И второй код после рефакторинга с моим код стайлом

from typing import List, Union
from tensorflow.keras.layers import Input, Dense, Dropout, Conv1D, MaxPooling1D, Flatten, SimpleRNN, GRU, LSTM, concatenate, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
import numpy as np

class GlobalParams:
    pass

def create_neural_network_model(g: GlobalParams) -> Model:
    if len(g.ohlcv) == 1:
        return _create_single_input_model(g)
    else:
        return _create_multi_input_model(g)

def _create_single_input_model(g: GlobalParams) -> Model:
    input_layer = _create_input_layer(g)
    x = _add_hidden_layers(g, input_layer)
    output_layer = _create_output_layer(g, x)
    return Model(input_layer, output_layer)

def _create_multi_input_model(g: GlobalParams) -> Model:
    input_layers: List[Input] = []
    intermediate_outputs: List[Layer] = []
    
    for _ in range(len(g.ohlcv)):
        input_layer = _create_input_layer(g)
        x = _add_hidden_layers(g, input_layer, is_multi_input=True)
        input_layers.append(input_layer)
        intermediate_outputs.append(x)
    
    combined = concatenate(intermediate_outputs)
    x = _add_remaining_layers(g, combined)
    output_layer = _create_output_layer(g, x)
    
    return Model(input_layers, output_layer)

def _create_input_layer(g: GlobalParams) -> Input:
    if g.type_model == 0:
        return Input((g.nr0,))
    elif g.type_model == 1:
        return Input((g.nr0, 1))
    elif g.type_model in [2, 3, 4]:
        return Input((1, g.nr0))

def _add_hidden_layers(g: GlobalParams, x: Layer, is_multi_input: bool = False) -> Layer:
    if not g.nr_leyers:
        return _add_default_layers(g, x)
    else:
        return _add_custom_layers(g, x, is_multi_input)

def _add_default_layers(g: GlobalParams, x: Layer) -> Layer:
    neurons = g.nr0
    for _ in range(100):
        neurons = int(neurons / 2)
        if neurons < 10:
            break
        x = Dense(neurons, activation=g.act_list_model, kernel_initializer=g.init_list_model)(x)
        x = Dropout(g.slider2.val if g.slider2.val != g.slider2.valinit else g.dropout)(x)
    return x

def _add_custom_layers(g: GlobalParams, x: Layer, is_multi_input: bool) -> Layer:
    flat = 0
    dropout_count = 0
    rnn_types = ['simple_rnn', 'gru', 'lstm']
    rnn_layers = [SimpleRNN, GRU, LSTM]

    for layer in g.nr_leyers:
        layer_type = layer[0]
        
        if layer_type == 'dense':
            x = _add_dense_layer(g, x, layer[1], flat)
            flat = 1 if g.type_model == 1 and flat == 0 else flat
        elif layer_type == 'dense_l2':
            x = _add_dense_l2_layer(g, x, layer[1], layer[2])
        elif layer_type == 'dropout':
            x = _add_dropout_layer(g, x, layer[1])
            dropout_count += 1
            if is_multi_input and dropout_count == g.dropout_cut:
                break
        elif layer_type == 'conv1d':
            x = _add_conv1d_layer(g, x, layer[1], layer[2])
        elif layer_type == 'max_pooling1d':
            x = MaxPooling1D(layer[1])(x)
        elif layer_type in rnn_types:
            x = _add_rnn_layer(g, x, layer, rnn_types, rnn_layers, flat)
            flat += 1
    
    return x

def _add_dense_layer(g: GlobalParams, x: Layer, units: int, flat: int) -> Layer:
    if g.type_model == 1 and flat == 0:
        x = Flatten()(x)
    return Dense(units, activation=g.act_list_model, kernel_initializer=g.init_list_model, bias_initializer=g.init_list_model)(x)

def _add_dense_l2_layer(g: GlobalParams, x: Layer, units: int, l2_rate: float) -> Layer:
    return Dense(units, activation=g.act_list_model, kernel_initializer=g.init_list_model, bias_initializer=g.init_list_model, 
                 kernel_regularizer=l2(l2_rate), bias_regularizer=l2(l2_rate))(x)

def _add_dropout_layer(g: GlobalParams, x: Layer, rate: float) -> Layer:
    dropout_rate = g.slider2.val if g.slider2.val != g.slider2.valinit else rate
    return Dropout(dropout_rate)(x)

def _add_conv1d_layer(g: GlobalParams, x: Layer, filters: int, kernel_size: int) -> Layer:
    return Conv1D(filters, kernel_size, strides=1, padding='same', activation=g.act_list_model, kernel_initializer=g.init_list_model)(x)

def _add_rnn_layer(g: GlobalParams, x: Layer, layer: List[Union[str, int]], rnn_types: List[str], rnn_layers: List[type], flat: int) -> Layer:
    rnn_index = rnn_types.index(layer[0])
    return_sequences = flat < sum(1 for l in g.nr_leyers if l[0] in rnn_types) - 1
    return rnn_layers[rnn_index](layer[1], activation=g.act_list_model, recurrent_dropout=g.dropout, 
                                 kernel_initializer=g.init_list_model, recurrent_initializer=g.init_list_model, 
                                 bias_initializer=g.init_list_model, return_sequences=return_sequences)(x)

def _add_remaining_layers(g: GlobalParams, x: Layer) -> Layer:
    if not g.nr_leyers:
        return _add_remaining_default_layers(g, x)
    else:
        return _add_remaining_custom_layers(g, x)

def _add_remaining_default_layers(g: GlobalParams, x: Layer) -> Layer:
    neurons = g.nr0
    for _ in range(100):
        neurons = int(neurons / 2)
        if neurons < 10:
            break
        x = Dense(neurons, activation=g.act_list_model, kernel_initializer=g.init_list_model)(x)
        x = Dropout(g.slider2.val if g.slider2.val != g.slider2.valinit else g.dropout)(x)
    return x

def _add_remaining_custom_layers(g: GlobalParams, x: Layer) -> Layer:
    return _add_custom_layers(g, x, is_multi_input=False)

def _create_output_layer(g: GlobalParams, x: Layer) -> Layer:
    return Dense(2 if g.bsl == "bsl" else 1, activation=g.act_list_model, kernel_initializer=g

                 ... что-то там ...

Ну как вам такой code style?

Зарегистрируйтесь на Хабре, чтобы оставить комментарий

Публикации

Истории