代码之家  ›  专栏  ›  技术社区  ›  CoccoSyn

Keras TypeError:One of the inputs does not have acceptable types(使用不规则张量 RaggedTensor 的 LSTM)

  •  0
  • CoccoSyn  · 技术社区  · 2 年前

    我正在尝试用不规则张量(RaggedTensor)构建一个 LSTM,目标是得到一个能够接收并预测不同长度动作的动作识别模型。(使用 TensorFlow 2.6)

    输入的形状如下:<tf.Tensor: shape=(3,), dtype=int64, numpy=array([30, 10, 1662], dtype=int64)>,即 N 个序列(此处 N = 30),每个序列包含 10 组、每组 1662 个点。

    模型如下所示:

    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=[None, 1662], ragged=True),
        tf.keras.layers.LSTM(64, return_sequences = True, activation = 'tanh'),
        tf.keras.layers.LSTM(128, return_sequences = True, activation = 'tanh'),
        tf.keras.layers.LSTM(64, return_sequences = True, activation = 'tanh'),
        tf.keras.layers.TimeDistributed(Dense(64, activation = 'relu')),
        tf.keras.layers.TimeDistributed(Dense(32, activation = 'relu')),
        tf.keras.layers.TimeDistributed(Dense(1, activation = 'softmax')),
    ])
    

    必须用 TimeDistributed 包装 Dense 层,才能匹配不规则张量的形状,因此输出也是 3D 的。

    编译和拟合:

    model.compile(optimizer = 'Adam', loss = 'categorical_crossentropy', metrics = ['categorical_accuracy'])
    model.fit(X_train, y_train, epochs = 4000, callbacks = callb)
    

    此代码引发错误:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    ~\AppData\Local\Temp\ipykernel_23244\613233899.py in <module>
          1 #problem might be due to the timesdistributed that takes in 3d vectors, while the y labels are 2d ones.. look into that
          2 callb = myCallback()
    ----> 3 model.fit(X_train, y_train, epochs = 4000, callbacks = callb)
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
       1191                 _r=1):
       1192               callbacks.on_train_batch_begin(step)
    -> 1193               tmp_logs = self.train_function(iterator)
       1194               if data_handler.should_sync:
       1195                 context.async_wait()
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
        883 
        884       with OptionalXlaContext(self._jit_compile):
    --> 885         result = self._call(*args, **kwds)
        886 
        887       new_tracing_count = self.experimental_get_tracing_count()
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
        931       # This is the first call of __call__, so we have to initialize.
        932       initializers = []
    --> 933       self._initialize(args, kwds, add_initializers_to=initializers)
        934     finally:
        935       # At this point we know that the initialization is complete (or less
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\def_function.py in _initialize(self, args, kwds, add_initializers_to)
        758     self._concrete_stateful_fn = (
        759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    --> 760             *args, **kwds))
        761 
        762     def invalid_creator_scope(*unused_args, **unused_kwds):
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
       3064       args, kwargs = None, None
       3065     with self._lock:
    -> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
       3067     return graph_function
       3068 
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\function.py in _maybe_define_function(self, args, kwargs)
       3461 
       3462           self._function_cache.missed.add(call_context_key)
    -> 3463           graph_function = self._create_graph_function(args, kwargs)
       3464           self._function_cache.primary[cache_key] = graph_function
       3465 
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
       3306             arg_names=arg_names,
       3307             override_flat_arg_shapes=override_flat_arg_shapes,
    -> 3308             capture_by_value=self._capture_by_value),
       3309         self._function_attributes,
       3310         function_spec=self.function_spec,
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
       1005         _, original_func = tf_decorator.unwrap(python_func)
       1006 
    -> 1007       func_outputs = python_func(*func_args, **func_kwargs)
       1008 
       1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\eager\def_function.py in wrapped_fn(*args, **kwds)
        666         # the function a weak reference to itself to avoid a reference cycle.
        667         with OptionalXlaContext(compile_with_xla):
    --> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
        669         return out
        670 
    
    ~\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\framework\func_graph.py in wrapper(*args, **kwargs)
        992           except Exception as e:  # pylint:disable=broad-except
        993             if hasattr(e, "ag_error_metadata"):
    --> 994               raise e.ag_error_metadata.to_exception(e)
        995             else:
        996               raise
    
    TypeError: in user code:
    
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\engine\training.py:862 train_function  *
            return step_function(self, iterator)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\engine\training.py:852 step_function  **
            outputs = model.distribute_strategy.run(run_step, args=(data,))
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286 run
            return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849 call_for_each_replica
            return self._call_for_each_replica(fn, args, kwargs)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632 _call_for_each_replica
            return fn(*args, **kwargs)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\engine\training.py:845 run_step  **
            outputs = model.train_step(data)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\engine\training.py:807 train_step
            self.compiled_metrics.update_state(y, y_pred, sample_weight)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\engine\compile_utils.py:460 update_state
            metric_obj.update_state(y_t, y_p, sample_weight=mask)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\utils\metrics_utils.py:88 decorated
            update_op = update_state_fn(*args, **kwargs)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\metrics.py:171 update_state_fn
            return ag_update_state(*args, **kwargs)
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\metrics.py:670 update_state  **
            [y_true, y_pred], sample_weight))
        C:\Users\cocco\anaconda3\envs\mediapipe1\lib\site-packages\tensorflow\python\keras\utils\metrics_utils.py:826 ragged_assert_compatible_and_get_flat_values
            raise TypeError('One of the inputs does not have acceptable types.')
    
        TypeError: One of the inputs does not have acceptable types.
    
    

    我只找到了这个(this,原帖中为一个链接),但它并没有多大帮助。我曾尝试有选择地删除一些输入,以检查是否是数据本身有问题,但事实并非如此。这些数据是使用 MediaPipe 进行动作跟踪得到的关键点(landmarks)。

    0 回复  |  直到 2 年前